Skip to content

Instantly share code, notes, and snippets.

@rlapz
Last active April 4, 2023 19:43
Show Gist options
  • Save rlapz/5776e86bc6715a1bb5c7cef9dc1c9fcc to your computer and use it in GitHub Desktop.
Save rlapz/5776e86bc6715a1bb5c7cef9dc1c9fcc to your computer and use it in GitHub Desktop.
liburing exercise
/* compile: cc fturing.c -o fturing -luring -DNDEBUG -O3 */
#include <endian.h>
#include <errno.h>
#include <error.h>
#include <fcntl.h>
#include <inttypes.h>
#include <libgen.h>
#include <netdb.h>
#include <signal.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>

#include <liburing.h>
#ifdef NDEBUG
/*
 * Release build.
 * likely(X):   hint that X is usually true.
 * unlikely(X): hint that X is usually false.
 * Both evaluate to !!(X), so they can be used directly in conditions.
 */
#define likely(X) __builtin_expect(!!(X), 1)
#define unlikely(X) __builtin_expect(!!(X), 0)
#define HOT __attribute__((__hot__))
#define INLINE inline __attribute__((__always_inline__))
/* debug output compiles to a no-op statement */
#define DPRINT(...) ((void)0)
/* ERRNUM may be an errno value or a negated one (-errno); abs() accepts both */
#define PERROR(ERRNUM, MSG)\
	fprintf(stderr, "Error: " MSG ": %s\n", strerror(abs(ERRNUM)))
#else
/* Debug build: no hints or attributes, verbose diagnostics with file/line. */
#define likely(X) (X)
#define unlikely(X) (X)
#define HOT
#define INLINE
#define DPRINT(...)\
	printf(__VA_ARGS__)
#define PERROR(ERRNUM, MSG)\
	error_at_line(0, abs(ERRNUM), __FILE__, __LINE__, \
		      "Error: %s: " MSG, __func__)
#endif
/* Directory (relative to the server's cwd) where received files are stored.
 * Kept as a macro: callers use sizeof(UPLOAD_DIR) on the string literal. */
#define UPLOAD_DIR "upload_dir"
enum {
	BUFFER_SIZE = 8192,		/* bytes per recv/read/write chunk */
	SERVER_QUEUE_DEPTH = 32,	/* io_uring submission queue entries */
	SERVER_CLIENTS_MAX = 1024,	/* SClient slots pre-allocated in the pool */
};
/*
 * Packet: header the client sends before the file contents.
 * The struct is packed, so its in-memory layout is the wire layout.
 */
typedef struct __attribute__((__packed__)) {
	uint64_t fsize;		/* big-endian, see packet_fsize_set() */
	uint8_t fname_len;	/* number of meaningful bytes in fname */
	char fname[255];
} Packet;
#define PACKET_SIZE (sizeof(Packet))

static inline void packet_fname_set(Packet *self, const char *fname,
				    uint8_t len);
static inline const char *packet_fname_get(Packet *self);
static inline void packet_fsize_set(Packet *self, uint64_t fsize);
static inline uint64_t packet_fsize_get(const Packet *self);
static inline int packet_check(const Packet *self);
/*
 * MemPool: intrusive free list of fixed-size chunks.
 * Chunks that are handed out are not tracked; only pooled chunks are linked.
 */
/* Tag renamed from `_Node`: identifiers beginning with an underscore and an
 * uppercase letter are reserved for the implementation (C11 7.1.3). */
typedef struct mempool_node {
	struct mempool_node *prev;
} MemPoolNode;
typedef struct {
	MemPoolNode node;	/* link in the free list while the chunk is pooled */
	uint8_t ptr[];		/* payload handed out by mempool_alloc() */
} MemPoolItem;
typedef struct {
	size_t chunks;		/* bytes per chunk: MemPoolItem header + payload */
	MemPoolItem *items;	/* top of the free list, NULL when empty */
} MemPool;
/* nmemb: payload bytes per chunk; size: number of chunks to pre-allocate */
static MemPool mempool_init(size_t nmemb, size_t size);
static void mempool_deinit(MemPool *self);
static void *mempool_alloc(MemPool *self);
static void mempool_free(MemPool *self, void *mem);
/*
 * Utils
 */
/* sockets: both return a connected/listening descriptor or -1 */
static int socket_server_create(const char *host, const char *port);
static int socket_client_create(const char *addr, const char *port);

/* io_uring request helpers: each queues one SQE tagged with udata */
static void uring_prep_accept(struct io_uring *self, int sock_fd);
static void uring_prep_recv(struct io_uring *self, int sock_fd,
			    uint8_t *buffer, unsigned size, void *udata);
static void uring_prep_write(struct io_uring *self, int file_fd,
			     const uint8_t *buffer, unsigned size, __u64 offt,
			     void *udata);
static struct io_uring_sqe *uring_sqe_get(struct io_uring *self);

/* signals: ignore SIGPIPE, route SIGINT/SIGTERM/SIGHUP to handler */
static int signal_set(void (*handler)(int sig));
/*
 * SClient: per-connection state on the server side
 */
enum {
	SCLIENT_STATE_PROP,	/* receiving the Packet header */
	SCLIENT_STATE_RECV,	/* waiting for the next chunk of file data */
	SCLIENT_STATE_WRITE,	/* waiting for a chunk to be written to the file */
	SCLIENT_STATE_FINISH,	/* returned by a handler to close the connection */
};
typedef struct {
	int state;
	int sock_fd;
	int file_fd;		/* -1 until the header has been validated */
	/* Explicit padding. Renamed from `__pad`, because identifiers that
	 * begin with `__` are reserved for any use (C11 7.1.3). */
	int pad;
	uint64_t fsize;		/* file size announced in the header */
	uint64_t bytes;		/* header bytes received (PROP), then file bytes written */
	struct io_uring *uring;
	union {
		Packet pkt;	/* header view of the receive buffer */
		uint8_t raw[BUFFER_SIZE];
	};
} SClient;
static void sclient_set(SClient *self, int sock_fd, struct io_uring *uring);
static void sclient_unset(SClient *self);
static INLINE int sclient_handle(SClient *self, int res);
static int sclient_handle_file_prop(SClient *self, int res);
static HOT int sclient_handle_file_recv(SClient *self, int res);
static HOT int sclient_handle_file_write(SClient *self, int res);
static int sclient_file_prep(SClient *self);
/*
 * Server
 */
enum {
	/* user_data tag of accept completions; every other request is tagged
	 * with the SClient pointer it belongs to */
	SERVER_STATE_CLIENT_ADD,
};
typedef struct {
	/* Cleared from the SIGINT/SIGTERM/SIGHUP handler (via server_stop()),
	 * so it must be volatile sig_atomic_t to be written there safely. */
	volatile sig_atomic_t is_alive;
	int sock_fd;
	MemPool clients;
	struct io_uring uring;
} Server;
static int server_init(Server *self);
static void server_deinit(Server *self);
static int server_run(Server *self, const char host[], const char port[]);
static INLINE int server_handle_events(Server *self);
static void server_stop(Server *self);
static void server_accept(Server *self, int sock_fd);
/*
 * Client
 */
typedef struct client {
	/* Cleared from the SIGINT/SIGTERM/SIGHUP handler (via client_stop()),
	 * so it must be volatile sig_atomic_t to be written there safely. */
	volatile sig_atomic_t is_alive;
	int sock_fd;
	int file_fd;
	/* Explicit padding. Renamed from `__pad`, because identifiers that
	 * begin with `__` are reserved for any use (C11 7.1.3). */
	int pad;
	union {
		Packet pkt;	/* header, later overwritten by file data */
		char raw[BUFFER_SIZE];
	};
} Client;
static void client_init(Client *self);
static void client_deinit(Client *self);
static int client_run(Client *self, const char addr[], const char port[],
		      char file[]);
static int client_prep_prop(Client *self, char file[]);
static int client_send_prop(Client *self);
static int client_send_file(Client *self);
static void client_stop(Client *self);
/****************************************************************************
* IMPL *
***************************************************************************/
/*
* Packet
*/
/*
 * Copy `len` bytes of `fname` into the packet and NUL-terminate them.
 * fname[] holds at most sizeof(fname) - 1 characters plus the terminator.
 * A longer `len` is clamped: a 255-byte name would otherwise write the
 * NUL one past the end of fname[255].
 */
static inline void
packet_fname_set(Packet *self, const char fname[], uint8_t len)
{
	const size_t cap = sizeof(self->fname) - 1;
	const uint8_t n = ((size_t)len > cap) ? (uint8_t)cap : len;

	memcpy(self->fname, fname, n);
	self->fname[n] = '\0';
	self->fname_len = n;
}
/*
 * Return fname as a C string, terminating it in place at fname_len.
 * fname_len can come straight from the network. A value of 255 would index
 * one past fname[255], so the terminator position is clamped to the last
 * slot.
 */
static inline const char *
packet_fname_get(Packet *self)
{
	size_t len = self->fname_len;

	if (len >= sizeof(self->fname))
		len = sizeof(self->fname) - 1;
	self->fname[len] = '\0';
	return self->fname;
}
/* Store the file size in network (big-endian) byte order. */
static inline void
packet_fsize_set(Packet *self, uint64_t fsize)
{
	const uint64_t wire = htobe64(fsize);
	self->fsize = wire;
}
/* Read the big-endian file size back in host byte order. */
static inline uint64_t
packet_fsize_get(const Packet *self)
{
	const uint64_t wire = self->fsize;
	return be64toh(wire);
}
/*
 * Validate the file name of a header.
 * Returns 0 if acceptable, otherwise -EINVAL in these cases:
 *   - the name is empty;
 *   - fname_len is too long to be NUL-terminated inside fname[];
 *   - the name contains "..", which blocks escaping the upload directory.
 * Only the first fname_len bytes are scanned, so a packet that was never
 * NUL-terminated is not read past the end of fname[].
 */
static inline int
packet_check(const Packet *self)
{
	const size_t len = self->fname_len;

	if (len == 0 || len >= sizeof(self->fname))
		return -EINVAL;

	for (size_t i = 0; i + 1 < len; i++) {
		if (self->fname[i] == '.' && self->fname[i + 1] == '.')
			return -EINVAL;
	}
	return 0;
}
/*
* MemPool
*/
static MemPool
mempool_init(size_t nmemb, size_t size)
{
MemPool new_mem = {
.chunks = sizeof(MemPoolItem) + nmemb,
.items = NULL,
};
for (size_t i = 0; i < size; i++) {
MemPoolItem *const new_item = malloc(new_mem.chunks);
if (new_item == NULL) {
mempool_deinit(&new_mem);
break;
}
if (new_mem.items != NULL)
new_item->node.prev = &new_mem.items->node;
else
new_item->node.prev = NULL;
new_mem.items = new_item;
}
return new_mem;
}
/* Free every chunk currently sitting in the pool and leave it empty.
 * Chunks that are handed out (not returned with mempool_free()) are not
 * reachable from here and are not freed. */
static void
mempool_deinit(MemPool *self)
{
	for (MemPoolItem *cur = self->items; cur != NULL;) {
		/* read the link before the chunk is released */
		MemPoolItem *const next = (MemPoolItem *)cur->node.prev;
		free(cur);
		cur = next;
	}
	self->items = NULL;
}
/* Hand out a payload pointer: pop a pooled chunk if one is available,
 * otherwise malloc() a new chunk of the same size. Returns NULL on OOM. */
static void *
mempool_alloc(MemPool *self)
{
	MemPoolItem *item;

	if (self->items == NULL) {
		// pool is empty: fall back to the heap
		item = malloc(self->chunks);
		if (item == NULL)
			return NULL;
	} else {
		// pop the most recently pooled chunk
		item = self->items;
		self->items = (MemPoolItem *)item->node.prev;
	}
	return item->ptr;
}
/*
 * Return a payload pointer obtained from mempool_alloc() to the pool.
 * The chunk is pushed onto the free list, not released to the heap.
 * NULL is ignored.
 */
static void
mempool_free(MemPool *self, void *mem)
{
	if (mem == NULL)
		return;

	/* mem is &item->ptr; offsetof gives the exact start of the flexible
	 * array member. sizeof(MemPoolItem) may include trailing padding and
	 * is not guaranteed to equal that offset. */
	MemPoolItem *const item =
		(MemPoolItem *)((uint8_t *)mem - offsetof(MemPoolItem, ptr));

	item->node.prev = (self->items != NULL) ? &self->items->node : NULL;
	// putting back the allocated memory
	self->items = item;
	DPRINT("%p: freed\n", mem);
}
/*
* Utils
*/
/*
 * Enable SO_REUSEADDR on an already created socket, then bind and listen
 * on the address `p`.
 * Returns 0 on success, or -1 after reporting which call failed.
 */
static int
socket_server_listen(int sock_fd, const struct addrinfo *p)
{
	enum { LISTEN_BACKLOG = 32 };
	const int sock_opt = 1;

	if (setsockopt(sock_fd, SOL_SOCKET, SO_REUSEADDR, &sock_opt,
		       sizeof(sock_opt)) < 0) {
		PERROR(errno, "setsockopt");
		return -1;
	}
	if (bind(sock_fd, p->ai_addr, p->ai_addrlen) < 0) {
		PERROR(errno, "bind");
		return -1;
	}
	if (listen(sock_fd, LISTEN_BACKLOG) < 0) {
		PERROR(errno, "listen");
		return -1;
	}
	return 0;
}

/*
 * Create a listening TCP socket for host:port. Every address returned by
 * getaddrinfo() is tried until one can be bound.
 * Returns the socket descriptor, or -1 on failure.
 */
static int
socket_server_create(const char host[], const char port[])
{
	struct addrinfo *ai;
	const struct addrinfo hints = {
		.ai_family = AF_UNSPEC,
		.ai_socktype = SOCK_STREAM,
		/* use the wildcard address when host is NULL */
		.ai_flags = AI_PASSIVE,
	};

	const int gai_ret = getaddrinfo(host, port, &hints, &ai);
	if (gai_ret != 0) {
		fprintf(stderr, "socket_server_create: getaddrinfo: %s\n",
			gai_strerror(gai_ret));
		return -1;
	}

	/* Success is tracked through sock_fd, because the list cursor must
	 * not be inspected after freeaddrinfo(). */
	int sock_fd = -1;
	for (const struct addrinfo *p = ai; p != NULL; p = p->ai_next) {
		sock_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
		if (sock_fd < 0) {
			PERROR(errno, "socket");
			continue;
		}
		if (socket_server_listen(sock_fd, p) == 0)
			break;
		close(sock_fd);
		sock_fd = -1;
	}
	freeaddrinfo(ai);

	if (sock_fd < 0) {
		fprintf(stderr, "socket_server_create: no usable address for %s:%s\n",
			host != NULL ? host : "*", port != NULL ? port : "?");
		return -1;
	}
	return sock_fd;
}
/*
 * Connect a TCP socket to addr:port. Every address returned by
 * getaddrinfo() is tried until a connect() succeeds.
 * Returns the connected descriptor, or -1 on failure.
 */
static int
socket_client_create(const char addr[], const char port[])
{
	struct addrinfo *ai;
	const struct addrinfo hints = {
		.ai_family = AF_UNSPEC,
		.ai_socktype = SOCK_STREAM,
	};

	const int gai_ret = getaddrinfo(addr, port, &hints, &ai);
	if (gai_ret != 0) {
		fprintf(stderr, "socket_client_create: getaddrinfo: %s\n",
			gai_strerror(gai_ret));
		return -1;
	}

	/* Success is tracked through sock_fd, because the list cursor must
	 * not be inspected after freeaddrinfo(). */
	int sock_fd = -1;
	for (const struct addrinfo *p = ai; p != NULL; p = p->ai_next) {
		sock_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
		if (sock_fd < 0) {
			PERROR(errno, "socket");
			continue;
		}
		if (connect(sock_fd, p->ai_addr, p->ai_addrlen) == 0)
			break;	/* connected */
		PERROR(errno, "connect");
		close(sock_fd);
		sock_fd = -1;
	}
	freeaddrinfo(ai);

	if (sock_fd < 0) {
		fprintf(stderr, "socket_client_create: unable to connect to %s:%s\n",
			addr != NULL ? addr : "(null)",
			port != NULL ? port : "(null)");
		return -1;
	}
	return sock_fd;
}
/*
 * Queue an accept on the listening socket. The request is tagged with
 * SERVER_STATE_CLIENT_ADD so the event loop can tell it apart from client
 * requests.
 */
static void
uring_prep_accept(struct io_uring *self, int sock_fd)
{
	struct io_uring_sqe *const sqe = uring_sqe_get(self);

	/* uring_sqe_get() can still return NULL after its retry; writing
	 * through it would crash the whole server. */
	if (sqe == NULL) {
		fputs("Error: uring_prep_accept: no free SQE\n", stderr);
		return;
	}
	io_uring_prep_accept(sqe, sock_fd, NULL, NULL, 0);
	io_uring_sqe_set_data64(sqe, SERVER_STATE_CLIENT_ADD);
}
/* Queue a recv of up to `size` bytes into `buffer`, tagged with `udata`. */
static void
uring_prep_recv(struct io_uring *self, int sock_fd, uint8_t buffer[],
		unsigned size, void *udata)
{
	struct io_uring_sqe *const sqe = uring_sqe_get(self);

	/* uring_sqe_get() can still return NULL after its retry. Dropping the
	 * request presumably leaves `udata` without a pending completion, but
	 * that is preferable to dereferencing NULL. */
	if (sqe == NULL) {
		fputs("Error: uring_prep_recv: no free SQE\n", stderr);
		return;
	}
	io_uring_prep_recv(sqe, sock_fd, buffer, size, 0);
	io_uring_sqe_set_data(sqe, udata);
}
/* Queue a write of `size` bytes from `buffer` at file offset `offt`,
 * tagged with `udata`. */
static void
uring_prep_write(struct io_uring *self, int file_fd, const uint8_t buffer[],
		 unsigned size, __u64 offt, void *udata)
{
	struct io_uring_sqe *const sqe = uring_sqe_get(self);

	/* uring_sqe_get() can still return NULL after its retry. Dropping the
	 * request presumably leaves `udata` without a pending completion, but
	 * that is preferable to dereferencing NULL. */
	if (sqe == NULL) {
		fputs("Error: uring_prep_write: no free SQE\n", stderr);
		return;
	}
	io_uring_prep_write(sqe, file_fd, buffer, size, offt);
	io_uring_sqe_set_data(sqe, udata);
}
/*
 * Get a free submission queue entry. If the queue is full, submit what is
 * pending and retry once.
 * The result of the retry is returned as-is and may be NULL.
 */
static struct io_uring_sqe *
uring_sqe_get(struct io_uring *self)
{
	struct io_uring_sqe *const first = io_uring_get_sqe(self);
	if (first != NULL)
		return first;

	io_uring_submit(self);
	return io_uring_get_sqe(self);
}
static int
signal_set(void (*handler)(int sig))
{
struct sigaction act = {
.sa_handler = SIG_IGN,
};
if (sigaction(SIGPIPE, &act, NULL) < 0) {
PERROR(errno, "sigaction: ~SIGPIPE");
return -1;
}
act.sa_handler = handler;
if (sigaction(SIGINT, &act, NULL) < 0) {
PERROR(errno, "sigaction: ~SIGINT");
return -1;
}
if (sigaction(SIGTERM, &act, NULL) < 0) {
PERROR(errno, "sigaction: ~SIGTERM");
return -1;
}
if (sigaction(SIGHUP, &act, NULL) < 0) {
PERROR(errno, "sigaction: ~SIGHUP");
return -1;
}
return 0;
}
/*
* SClient
*/
/* Initialise a freshly allocated SClient for the accepted socket `sock_fd`.
 * The connection starts by waiting for the Packet header. */
static void
sclient_set(SClient *self, int sock_fd, struct io_uring *uring)
{
	self->uring = uring;
	self->sock_fd = sock_fd;
	self->file_fd = -1;	/* no file until the header is validated */
	self->fsize = self->bytes = 0;
	self->state = SCLIENT_STATE_PROP;
}
/*
 * Close the descriptors owned by an SClient.
 * 0 is a valid descriptor, so open descriptors are detected with >= 0;
 * file_fd is -1 when no file was ever opened (see sclient_set()).
 * sock_fd keeps its value because the caller still prints it afterwards.
 */
static void
sclient_unset(SClient *self)
{
	if (self->sock_fd >= 0)
		close(self->sock_fd);
	if (self->file_fd >= 0) {
		close(self->file_fd);
		self->file_fd = -1;
	}
}
/*
 * Dispatch one completion result `res` for this client to the handler for
 * its current state.
 * Returns 0 while the connection stays open.
 * Returns -1 once it has been closed; the caller then releases the SClient.
 */
static INLINE int
sclient_handle(SClient *self, int res)
{
	if (unlikely(res <= 0)) {
		/* For a recv, 0 means the peer closed the connection. A negative
		 * value is the -errno of the failed recv/write; report it rather
		 * than dropping it. */
		if (res < 0)
			fprintf(stderr, "Error: sclient_handle: %s\n",
				strerror(-res));
		goto out;
	}

	int state;
	switch (self->state) {
	case SCLIENT_STATE_PROP:
		state = sclient_handle_file_prop(self, res);
		break;
	case SCLIENT_STATE_RECV:
		state = sclient_handle_file_recv(self, res);
		break;
	case SCLIENT_STATE_WRITE:
		state = sclient_handle_file_write(self, res);
		break;
	default:
		/* unknown state: the SClient is corrupted, drop the connection */
		state = SCLIENT_STATE_FINISH;
		break;
	}

	if (state == SCLIENT_STATE_FINISH)
		goto out;

	self->state = state;
	return 0;

out:
	sclient_unset(self);
	printf("closed connection: %d\n", self->sock_fd);
	return -1;
}
/*
 * Accumulate the Packet header. `res` is the byte count of the last recv.
 * While the header is incomplete, a recv is queued for the missing tail.
 * Once complete, the target file is opened and reception of file data
 * starts.
 */
static int
sclient_handle_file_prop(SClient *self, int res)
{
	DPRINT("packet_size: %d: %zu\n", res, PACKET_SIZE);
	const uint64_t total = self->bytes + (uint64_t)res;

	if (total < PACKET_SIZE) {
		/* header still incomplete: ask for exactly the missing bytes */
		self->bytes = total;
		uring_prep_recv(self->uring, self->sock_fd, &self->raw[total],
				(unsigned)(PACKET_SIZE - total), self);
		return SCLIENT_STATE_PROP;
	}

	/* recv sizes never exceed the missing tail, so this is defensive */
	if (unlikely(total > PACKET_SIZE)) {
		DPRINT("corrupted\n");
		return SCLIENT_STATE_FINISH;
	}

	if (unlikely(sclient_file_prep(self) < 0))
		return SCLIENT_STATE_FINISH;

	self->bytes = 0;
	uring_prep_recv(self->uring, self->sock_fd, self->raw, BUFFER_SIZE, self);
	return SCLIENT_STATE_RECV;
}
/*
 * A chunk of file data (`res` > 0 bytes) has arrived in self->raw.
 * Queue it for writing at the current file offset.
 * Anything beyond the size announced in the header is dropped instead of
 * being appended to the file.
 */
static HOT int
sclient_handle_file_recv(SClient *self, int res)
{
	if (unlikely(self->bytes >= self->fsize))
		return SCLIENT_STATE_FINISH;

	const uint64_t remaining = self->fsize - self->bytes;
	unsigned len = (unsigned)res;
	if ((uint64_t)len > remaining)
		len = (unsigned)remaining;

	uring_prep_write(self->uring, self->file_fd, self->raw, len,
			 self->bytes, self);
	return SCLIENT_STATE_WRITE;
}
/*
 * A write of `res` bytes finished. Record it, then finish when the
 * announced size is reached. Otherwise queue the next recv.
 * NOTE(review): a short write (res smaller than the chunk) loses the
 * unwritten tail, because the next recv reuses self->raw and the chunk
 * length is not stored in SClient.
 */
static HOT int
sclient_handle_file_write(SClient *self, int res)
{
	// record written bytes
	self->bytes += (uint64_t)res;
	if (self->bytes >= self->fsize)
		return SCLIENT_STATE_FINISH;

	uring_prep_recv(self->uring, self->sock_fd, self->raw, BUFFER_SIZE, self);
	return SCLIENT_STATE_RECV;
}
/*
 * Validate the received header and create UPLOAD_DIR/<fname> for writing,
 * truncating any existing file.
 * Returns the new file descriptor (also stored in self->file_fd) or a
 * negative value on error.
 */
static int
sclient_file_prep(SClient *self)
{
	char path[4096];
	Packet *const pkt = &self->pkt;

	printf("File name: %s: %" PRIu64 "\n", packet_fname_get(pkt),
	       packet_fsize_get(pkt));

	int ret = packet_check(pkt);
	if (unlikely(ret < 0)) {
		/* packet_check() reports through its return value, not errno */
		PERROR(ret, "packet_check");
		return ret;
	}

	const char *const fname = packet_fname_get(pkt);
	ret = snprintf(path, sizeof(path), "%s/%s", UPLOAD_DIR, fname);
	if (unlikely(ret < 0)) {
		PERROR(errno, "snprintf");
		return ret;
	}
	if (unlikely((size_t)ret >= sizeof(path))) {
		PERROR(ENAMETOOLONG, "snprintf");
		return -ENAMETOOLONG;
	}

	ret = open(path, O_TRUNC | O_CREAT | O_WRONLY, 0644);
	if (unlikely(ret < 0)) {
		PERROR(errno, "open: ~file_fd");
		return ret;
	}

	self->fsize = packet_fsize_get(pkt);
	self->file_fd = ret;
	return ret;
}
/*
* Server
*/
/* Set up the io_uring instance and the SClient pool.
 * Returns 0, or the negative error from io_uring_queue_init(). */
static int
server_init(Server *self)
{
	const int err = io_uring_queue_init(SERVER_QUEUE_DEPTH, &self->uring, 0);
	if (err < 0) {
		PERROR(err, "io_uring_queue_init");
		return err;
	}

	self->clients = mempool_init(sizeof(SClient), SERVER_CLIENTS_MAX);
	self->sock_fd = -1;
	self->is_alive = 0;
	return 0;
}
/* Release the resources acquired by server_init().
 * Keep this order: the ring is torn down before the pool is released,
 * because queued recv/write requests point into pooled SClient buffers. */
static void
server_deinit(Server *self)
{
	struct io_uring *const ring = &self->uring;
	MemPool *const pool = &self->clients;

	io_uring_queue_exit(ring);
	mempool_deinit(pool);
}
/*
 * Listen on host:port and run the event loop until is_alive is cleared
 * (by server_stop()) or an unrecoverable error occurs.
 * Returns 0 on a clean stop, or a negative error.
 */
static int
server_run(Server *self, const char host[], const char port[])
{
	struct io_uring *const uring = &self->uring;

	int ret = socket_server_create(host, port);
	if (ret < 0)
		return ret;

	const int sock_fd = ret;
	self->sock_fd = sock_fd;
	uring_prep_accept(uring, sock_fd);

	ret = 0;
	self->is_alive = 1;
	while (self->is_alive) {
		ret = io_uring_submit_and_wait(uring, 1);
		if (unlikely(ret < 0)) {
			/* A handled signal interrupts the wait. The stop handler
			 * clears is_alive, so re-check the loop condition instead
			 * of treating this as a failure. */
			if (ret == -EINTR) {
				ret = 0;
				continue;
			}
			PERROR(ret, "io_uring_submit_and_wait");
			break;
		}

		ret = server_handle_events(self);
		if (unlikely(ret < 0))
			break;
	}

	close(sock_fd);
	self->sock_fd = -1;
	return ret;
}
/*
 * Drain every ready completion.
 * Accept completions (tag SERVER_STATE_CLIENT_ADD) register the new client
 * and re-arm the accept.
 * Any other tag is an SClient pointer; a client whose handler reports
 * closure goes back to the pool.
 * Always returns 0.
 */
static INLINE int
server_handle_events(Server *self)
{
	struct io_uring *const ring = &self->uring;
	struct io_uring_cqe *cqe;
	unsigned head, seen = 0;

	io_uring_for_each_cqe(ring, head, cqe) {
		const __u64 tag = cqe->user_data;

		if (tag == SERVER_STATE_CLIENT_ADD) {
			server_accept(self, cqe->res);
			uring_prep_accept(ring, self->sock_fd);
		} else {
			SClient *const client = (SClient *)tag;
			if (unlikely(sclient_handle(client, cqe->res) < 0))
				mempool_free(&self->clients, client);
		}
		seen++;
	}

	io_uring_cq_advance(ring, seen);
	return 0;
}
/*
 * Request the event loop to stop. This is invoked from the signal handler.
 * The listening socket is shut down (presumably so the pending accept
 * completes and wakes the loop), then is_alive is cleared.
 * 0 is a valid descriptor, hence the >= 0 test.
 * NOTE(review): PERROR is stdio-based and not async-signal-safe; it is only
 * reached on the error paths here.
 */
static void
server_stop(Server *self)
{
	if (!self->is_alive) {
		PERROR(EAGAIN, "~is_alive");
		return;
	}

	const int fd = self->sock_fd;
	if (fd >= 0 && shutdown(fd, SHUT_RDWR) < 0)
		PERROR(errno, "shutdown");

	self->is_alive = 0;
}
/*
 * Handle an accept completion. `sock_fd` is the new connection, or a
 * negative error.
 * On success an SClient is taken from the pool and a recv for the Packet
 * header is queued. If no SClient can be allocated, the socket is closed.
 */
static void
server_accept(Server *self, int sock_fd)
{
	if (unlikely(sock_fd < 0)) {
		PERROR(sock_fd, "~sock_fd");
		return;
	}

	SClient *const client = mempool_alloc(&self->clients);
	if (unlikely(client == NULL)) {
		PERROR(ENOMEM, "mempool_alloc");
		close(sock_fd);
		return;
	}

	sclient_set(client, sock_fd, &self->uring);
	uring_prep_recv(&self->uring, sock_fd, client->raw, PACKET_SIZE, client);
	printf("new connection: %d\n", client->sock_fd);
}
/*
* Client
*/
/* Put a Client into its idle state: no descriptors open, not running. */
static void
client_init(Client *self)
{
	self->file_fd = -1;
	self->sock_fd = -1;
	self->is_alive = 0;
}
/*
 * Close whatever descriptors the client still owns and mark them closed.
 * 0 is a valid descriptor, so open ones are detected with >= 0
 * (client_init() uses -1 for "not open").
 */
static void
client_deinit(Client *self)
{
	if (self->sock_fd >= 0) {
		close(self->sock_fd);
		self->sock_fd = -1;
	}
	if (self->file_fd >= 0) {
		close(self->file_fd);
		self->file_fd = -1;
	}
}
/*
 * Connect to addr:port, then send the header for `file` followed by its
 * contents.
 * Returns 0 on success, -1 if the connection fails, or a negative error
 * from the preparing/sending steps.
 */
static int
client_run(Client *self, const char addr[], const char port[], char file[])
{
	const int sock_fd = socket_client_create(addr, port);
	if (sock_fd < 0)
		return -1;
	self->sock_fd = sock_fd;

	int err = client_prep_prop(self, file);
	if (err < 0)
		return err;

	self->is_alive = 1;
	if ((err = client_send_prop(self)) < 0)
		return err;
	if ((err = client_send_file(self)) < 0)
		return err;
	return 0;
}
/*
 * Open `file` and build the Packet header (basename + size) in self->pkt.
 * On success the open descriptor is stored in self->file_fd and 0 is
 * returned. On failure the file is closed and a negative errno is returned.
 * Rejected inputs:
 *   - anything that is not a regular file. A directory opens fine with
 *     O_RDONLY, but its st_size is meaningless and read(2) on it fails.
 *   - names that do not fit, NUL-terminated, in Packet.fname. These were
 *     previously truncated to a uint8_t length.
 */
static int
client_prep_prop(Client *self, char file[])
{
	int ret;
	struct stat st;
	const char *fname;
	size_t fname_len;

	const int file_fd = open(file, O_RDONLY);
	if (file_fd < 0) {
		ret = -errno;
		PERROR(ret, "open");
		return ret;
	}

	if (fstat(file_fd, &st) < 0) {
		ret = -errno;
		PERROR(ret, "fstat");
		goto out0;
	}
	if (!S_ISREG(st.st_mode)) {
		ret = -EINVAL;
		PERROR(ret, "not a regular file");
		goto out0;
	}

	fname = basename(file);
	fname_len = strlen(fname);
	if (fname_len >= sizeof(self->pkt.fname)) {
		ret = -ENAMETOOLONG;
		PERROR(ret, "file name");
		goto out0;
	}

	memset(&self->pkt, 0, PACKET_SIZE);
	packet_fname_set(&self->pkt, fname, (uint8_t)fname_len);
	packet_fsize_set(&self->pkt, (uint64_t)st.st_size);
	ret = packet_check(&self->pkt);
	if (ret < 0) {
		PERROR(ret, "packet_check");
		goto out0;
	}

	printf( "File name: %s\n"
		"File size: %" PRIu64 "\n",
		packet_fname_get(&self->pkt), packet_fsize_get(&self->pkt));
	self->file_fd = file_fd;
	return 0;

out0:
	close(file_fd);
	return ret;
}
/*
 * Send the PACKET_SIZE header bytes held in self->raw, which shares
 * storage with self->pkt.
 * The offset into the buffer is always `sent`. The old code advanced the
 * pointer by the running total, so data was skipped after any partial send.
 * Returns 0 on success, -1 on error or if the header could not be sent
 * completely.
 */
static int
client_send_prop(Client *self)
{
	const char *const buff = self->raw;
	const int sock_fd = self->sock_fd;
	size_t sent = 0;

	while (likely(self->is_alive && sent < PACKET_SIZE)) {
		const ssize_t s = send(sock_fd, buff + sent, PACKET_SIZE - sent, 0);
		if (unlikely(s < 0)) {
			PERROR(errno, "send: file prop");
			return -1;
		}
		if (unlikely(s == 0))
			break;
		sent += (size_t)s;
	}

	if (sent != PACKET_SIZE) {
		PERROR(EINVAL, "send: broken file prop");
		return -1;
	}
	return 0;
}
/*
 * Stream exactly the number of bytes announced in the header from
 * self->file_fd to the socket.
 * Reads are capped at the remaining announced size. If the file grew after
 * fstat(), the extra bytes are not sent, so the receiver never gets more
 * than fsize.
 * Returns 0 on success or a negative errno. -EINVAL means the stream ended
 * early, or the transfer was stopped, before fsize bytes went out.
 */
static int
client_send_file(Client *self)
{
	int ret;
	const int file_fd = self->file_fd;
	const int sock_fd = self->sock_fd;
	/* self->raw shares storage with self->pkt, so read fsize once, before
	 * the buffer is reused for file data */
	const uint64_t fsize = packet_fsize_get(&self->pkt);
	char *const buff = self->raw;
	uint64_t wrtn = 0;

	while (likely(self->is_alive && wrtn < fsize)) {
		const uint64_t left = fsize - wrtn;
		const size_t want = (left < sizeof(self->raw)) ? (size_t)left
							       : sizeof(self->raw);
		const ssize_t r = read(file_fd, buff, want);
		if (unlikely(r < 0)) {
			ret = errno;
			PERROR(ret, "read");
			return -ret;
		}
		if (r == 0)
			break;

		size_t rd = 0;
		while (self->is_alive && rd < (size_t)r) {
			const ssize_t s = send(sock_fd, buff + rd,
					       (size_t)r - rd, 0);
			if (unlikely(s < 0)) {
				ret = errno;
				PERROR(ret, "send");
				return -ret;
			}
			if (s == 0)
				goto out;
			rd += (size_t)s;
		}
		wrtn += (uint64_t)rd;
	}

out:
	if (wrtn != fsize) {
		ret = EINVAL;
		PERROR(ret, "corrupted file");
		return -ret;
	}
	return 0;
}
/*
 * Request the transfer loops to stop. This is invoked from the signal
 * handler.
 * The socket is shut down, which makes pending send() calls fail, then
 * is_alive is cleared. 0 is a valid descriptor, hence the >= 0 test.
 * NOTE(review): PERROR is stdio-based and not async-signal-safe; it is only
 * reached on the error paths here.
 */
static void
client_stop(Client *self)
{
	if (!self->is_alive) {
		PERROR(EAGAIN, "~is_alive");
		return;
	}

	const int fd = self->sock_fd;
	if (fd >= 0 && shutdown(fd, SHUT_RDWR) < 0)
		PERROR(errno, "shutdown");

	self->is_alive = 0;
}
/*
* Main
*/
/* Objects the signal handlers act on; set by run_server()/run_client(). */
static Server *server_g;
static Client *client_g;

/*
 * Write "\nInterrupted: <sig>\n" to stderr using only async-signal-safe
 * operations.
 * fprintf() must not be called from a signal handler, since it may
 * interrupt another stdio call in progress. The number is therefore
 * formatted by hand and emitted with a single write(2).
 */
static void
signal_report(int sig)
{
	static const char prefix[] = "\nInterrupted: ";
	char buf[sizeof(prefix) + 16];
	char digits[12];
	size_t len = 0, ndigits = 0;
	unsigned int v = (sig < 0) ? 0u - (unsigned int)sig : (unsigned int)sig;

	for (size_t i = 0; i + 1 < sizeof(prefix); i++)
		buf[len++] = prefix[i];
	if (sig < 0)
		buf[len++] = '-';
	do {
		digits[ndigits++] = (char)('0' + v % 10u);
		v /= 10u;
	} while (v != 0);
	while (ndigits > 0)
		buf[len++] = digits[--ndigits];
	buf[len++] = '\n';

	const ssize_t w = write(STDERR_FILENO, buf, len);
	(void)w;
}

/* SIGINT/SIGTERM/SIGHUP handler for server mode. errno is preserved for
 * the interrupted code. */
static void
signal_handler_server(int sig)
{
	const int saved_errno = errno;
	signal_report(sig);
	server_stop(server_g);
	errno = saved_errno;
}

/* SIGINT/SIGTERM/SIGHUP handler for client mode. errno is preserved for
 * the interrupted code. */
static void
signal_handler_client(int sig)
{
	const int saved_errno = errno;
	signal_report(sig);
	client_stop(client_g);
	errno = saved_errno;
}
/*
 * Server mode entry point. argv[0] is the listen host and argv[1] the port.
 * Returns -1 if setup fails, otherwise the result of server_run().
 */
static int
run_server(char *argv[])
{
	Server server;
	if (server_init(&server) < 0)
		return -1;

	server_g = &server;
	int ret = signal_set(signal_handler_server);
	if (ret == 0)
		ret = server_run(&server, argv[0], argv[1]);

	server_deinit(&server);
	return ret;
}
/*
 * Client mode entry point. argv[0] is the server host, argv[1] the port and
 * argv[2] the file to upload.
 * Returns -1 if the signal setup fails, otherwise the result of
 * client_run().
 */
static int
run_client(char *argv[])
{
	Client client;
	client_init(&client);
	client_g = &client;

	if (signal_set(signal_handler_client) != 0)
		return -1;

	const int status = client_run(&client, argv[0], argv[1], argv[2]);
	client_deinit(&client);
	return status;
}
/*
 * Usage:
 *   prog client <server host> <server port> <file path>
 *   prog server <listen host> <listen port>
 * The exit status is the negated return value of the selected mode, or
 * EINVAL after printing usage.
 */
int
main(int argc, char *argv[])
{
#ifndef NDEBUG
	puts("Debug");
#endif
	/* argv[0] may be NULL when the program is exec'd with an empty argv */
	const char *const prog =
		(argc > 0 && argv[0] != NULL) ? argv[0] : "fturing";

	if (argc > 2) {
		if (strcmp(argv[1], "client") == 0) {
			if (argc == 5)
				return -run_client(argv + 2);
		} else if (strcmp(argv[1], "server") == 0) {
			if (argc == 4)
				return -run_server(argv + 2);
		}
	}

	printf("client: %s client [server host] [server port] [file path]\n"
	       "server: %s server [listen host] [listen port]\n",
	       prog, prog);
	return EINVAL;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment