Skip to content

Instantly share code, notes, and snippets.

@jweinst1
Created October 25, 2023 08:55
Show Gist options
  • Save jweinst1/8bba643ae40a028e73b9aae09df1877a to your computer and use it in GitHub Desktop.
Save jweinst1/8bba643ae40a028e73b9aae09df1877a to your computer and use it in GitHub Desktop.
Polling-based locking server and client in C.
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <signal.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <netdb.h>
#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <unistd.h>
#include <pthread.h>
#include <poll.h>
/* IPv4 loopback address used for both the listening and the connecting socket. */
static const char *local_host_ipv4 = "127.0.0.1";

/* TCP port numbers of the two members; main() hands one element to each thread. */
static unsigned short port_nums[2] = {12001, 12002};
/*
 * Create a TCP/IPv4 socket bound to host:portno and put it into listening
 * state (backlog of 10).
 *
 * portno   - port to bind, in host byte order.
 * host     - IPv4 address string as accepted by inet_aton(); NULL or an
 *            unparsable string is rejected.
 * blocking - when zero, O_NONBLOCK is set on the listening descriptor.
 *
 * Returns the listening descriptor, or -1 on any failure (a diagnostic is
 * written to stderr and no descriptor is leaked).
 */
static int create_server_socket(unsigned short portno, const char* host, int blocking) {
    struct sockaddr_in servaddr;

    /* Validate the address first so a bad host never costs a descriptor. */
    memset(&servaddr, 0, sizeof(servaddr));
    servaddr.sin_family = AF_INET;
    servaddr.sin_port = htons(portno);
    if (host == NULL || inet_aton(host, &(servaddr.sin_addr)) == 0) {
        fprintf(stderr, "Cannot convert host to ip!\n");
        return -1;
    }

    int server_fd = socket(AF_INET, SOCK_STREAM, 0);
    if (server_fd == -1) {
        fprintf(stderr, "Cannot create socket\n");
        return -1;
    }

    if (!blocking) {
        /* F_GETFL can fail, and F_SETFL only promises "not -1" on success. */
        int flags = fcntl(server_fd, F_GETFL, 0);
        if (flags == -1 || fcntl(server_fd, F_SETFL, flags | O_NONBLOCK) == -1) {
            fprintf(stderr, "Cannot set non blocking on socket\n");
            close(server_fd);
            return -1;
        }
    }

    if (bind(server_fd, (struct sockaddr*)&servaddr, sizeof(servaddr)) != 0) {
        fprintf(stderr, "Cannot bind socket\n");
        close(server_fd);
        return -1;
    }

    if (listen(server_fd, 10) != 0) {
        fprintf(stderr, "Cannot listen to socket\n");
        close(server_fd);
        return -1;
    }

    return server_fd;
}
/*
 * Create a TCP/IPv4 socket and connect it to host:portno.
 *
 * portno   - destination port, in host byte order.
 * host     - IPv4 address string as accepted by inet_aton(); NULL or an
 *            unparsable string is rejected.
 * blocking - when zero, O_NONBLOCK is set on the descriptor AFTER the
 *            connection has been established.
 *
 * Returns the connected descriptor, or -1 on failure (a diagnostic is
 * written to stderr and no descriptor is leaked).
 */
static int create_client_socket(unsigned short portno, const char* host, int blocking) {
    struct sockaddr_in servaddr;

    /* Validate the address first so a bad host never costs a descriptor. */
    memset(&servaddr, 0, sizeof(servaddr));
    servaddr.sin_family = AF_INET;
    servaddr.sin_port = htons(portno);
    if (host == NULL || inet_aton(host, &(servaddr.sin_addr)) == 0) {
        fprintf(stderr, "Cannot convert host to ip!\n");
        return -1;
    }

    int server_fd = socket(AF_INET, SOCK_STREAM, 0);
    if (server_fd == -1) {
        return -1;
    }

    for (;;) {
        if (connect(server_fd, (struct sockaddr*)&servaddr, sizeof(servaddr)) == 0) {
            break;
        }
        int connect_err = errno;

        if (connect_err == EISCONN) {
            /* An earlier, interrupted attempt has completed in the meantime. */
            break;
        }
        if (connect_err == EAGAIN) {
            /* Transient resource shortage: back off briefly instead of spinning. */
            poll(NULL, 0, 10);
            continue;
        }
        if (connect_err == EINTR || connect_err == EALREADY || connect_err == EINPROGRESS) {
            /*
             * The connection attempt keeps running in the background. Wait
             * until the socket is writable, then fetch the final outcome
             * from SO_ERROR rather than calling connect() again.
             */
            struct pollfd waiter;
            int ready;
            waiter.fd = server_fd;
            waiter.events = POLLOUT;
            waiter.revents = 0;
            do {
                ready = poll(&waiter, 1, -1);
            } while (ready == -1 && errno == EINTR);

            if (ready == -1) {
                connect_err = errno;
            } else {
                socklen_t err_len = sizeof(connect_err);
                if (getsockopt(server_fd, SOL_SOCKET, SO_ERROR, &connect_err, &err_len) == -1) {
                    connect_err = errno;
                }
            }
            if (connect_err == 0) {
                break;
            }
        }

        fprintf(stderr, "Failed to connect to host=%s, port=%u\n", host, (unsigned)portno);
        errno = connect_err; /* fprintf may have clobbered errno */
        perror("Error: ");
        close(server_fd);
        return -1;
    }

    if (!blocking) {
        /* F_GETFL can fail, and F_SETFL only promises "not -1" on success. */
        int flags = fcntl(server_fd, F_GETFL, 0);
        if (flags == -1 || fcntl(server_fd, F_SETFL, flags | O_NONBLOCK) == -1) {
            fprintf(stderr, "Cannot set non blocking on socket\n");
            close(server_fd);
            return -1;
        }
    }

    return server_fd;
}
/* Thread exit statuses; their addresses are handed to pthread_exit(). */
static int RET_GOOD = 0;
static int RET_BAD  = 1;

/* One-byte opcodes that lead every message exchanged between the two members. */
static const unsigned char MESSAGE_LOCK        = 1; /* request: claim the lock   */
static const unsigned char MESSAGE_UNLOCK      = 2; /* request: release the lock */
static const unsigned char MESSAGE_LOCK_RESP   = 3; /* reply to MESSAGE_LOCK     */
static const unsigned char MESSAGE_UNLOCK_RESP = 4; /* reply to MESSAGE_UNLOCK   */
/*
 * Send one protocol message on targetfd: a single opcode byte followed by a
 * 16-bit value in network byte order (3 bytes in total).
 *
 * The frame is assembled in one buffer and written until every byte has been
 * accepted, so short writes, EINTR and a momentarily full non-blocking socket
 * cannot leave half a message on the wire. A hard write error is reported on
 * stderr; the function has no return value to propagate it.
 *
 * targetfd - descriptor to write to.
 * op       - opcode byte.
 * val      - value in host byte order.
 */
static void server_write_request(int targetfd, const unsigned char op, const unsigned short val) {
    uint64_t tid64 = 0;
#if defined(__APPLE__)
    pthread_threadid_np(NULL, &tid64);
#else
    /* pthread_threadid_np() is an Apple extension; fall back to pthread_self(). */
    tid64 = (uint64_t)(uintptr_t)pthread_self();
#endif
    /* uint64_t is not unsigned long long everywhere, so cast for %llu. */
    printf("DEBUG tid=%llu write op %u and val %u\n", (unsigned long long)tid64, op, val);

    const unsigned short flipped_val = htons(val);
    unsigned char frame[sizeof(op) + sizeof(flipped_val)];
    frame[0] = op;
    memcpy(frame + sizeof(op), &flipped_val, sizeof(flipped_val));

    size_t sent = 0;
    while (sent < sizeof(frame)) {
        ssize_t n = write(targetfd, frame + sent, sizeof(frame) - sent);
        if (n > 0) {
            sent += (size_t)n;
            continue;
        }
        if (n == -1 && errno == EINTR) {
            continue;
        }
        if (n == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
            /* Send buffer is full: wait until the descriptor is writable again. */
            struct pollfd waiter;
            waiter.fd = targetfd;
            waiter.events = POLLOUT;
            waiter.revents = 0;
            if (poll(&waiter, 1, -1) != -1 || errno == EINTR) {
                continue;
            }
        }
        fprintf(stderr, "tid=%llu failed to write op %u on fd=%d, errno=%d\n",
                (unsigned long long)tid64, op, targetfd, errno);
        return;
    }
}
/*
 * Receive one protocol message from targetfd: a single opcode byte followed
 * by a 16-bit value in network byte order (3 bytes in total).
 *
 * On success *op holds the opcode and *val the value in host byte order.
 * If the whole frame cannot be obtained (end of stream, read error, or no
 * data at all on a non-blocking descriptor), a diagnostic is written to
 * stderr and both *op and *val are set to 0, so a truncated frame is never
 * reported as a message. errno is left as the failing read() set it.
 *
 * Once at least one byte of a frame has arrived, EAGAIN on a non-blocking
 * descriptor makes the function wait (poll) for the remaining bytes.
 */
static void server_read_request(int targetfd, unsigned char* op, unsigned short* val) {
    unsigned char frame[sizeof(*op) + sizeof(*val)];
    size_t received = 0;

    while (received < sizeof(frame)) {
        ssize_t n = read(targetfd, frame + received, sizeof(frame) - received);
        if (n > 0) {
            received += (size_t)n;
            continue;
        }
        if (n == -1 && errno == EINTR) {
            continue;
        }
        if (n == -1 && received > 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
            /* Mid-frame: the rest is still in flight, wait for it. */
            struct pollfd waiter;
            waiter.fd = targetfd;
            waiter.events = POLLIN;
            waiter.revents = 0;
            if (poll(&waiter, 1, -1) != -1 || errno == EINTR) {
                continue;
            }
        }

        /* End of stream (n == 0) or an unrecoverable error. */
        int saved_errno = errno;
        fprintf(stderr, "Incomplete read on fd=%d (got %zu of %zu bytes)\n",
                targetfd, received, sizeof(frame));
        errno = saved_errno; /* keep the read() error visible to the caller */
        *op = 0;
        *val = 0;
        return;
    }

    unsigned short net_value = 0;
    memcpy(&net_value, frame + sizeof(*op), sizeof(net_value));
    *op = frame[0];
    *val = ntohs(net_value);
}
static void* server_work(void* arg) {
int rand_sleep = 1500;
uint64_t tid64;
int my_server;
int core_conn;
struct pollfd pollers[10];
nfds_t poller_count = 1;
pthread_threadid_np(NULL, &tid64);
unsigned short* my_port = arg;
unsigned short locked_val = 0;
printf("tid=%llu, my port is %u\n", tid64, *my_port);
if (*my_port % 2 == 0) {
rand_sleep = 1803;
// create client connection
do {
sleep(1);
printf("tid=%llu, currently connecting\n", tid64);
} while ((core_conn = create_client_socket(*my_port - 1, local_host_ipv4, 0)) == -1);
} else {
rand_sleep = 506;
// create server connection
my_server = create_server_socket(*my_port, local_host_ipv4, 1);
if (my_server == -1) {
fprintf(stderr, "Thread id=%llu failed to create server socket!\n", tid64);
pthread_exit(&RET_BAD);
}
socklen_t clen = 0;
struct sockaddr_in cliaddr;
printf("tid=%llu, going to block on accept to see the other member\n", tid64);
int result = accept(my_server, (struct sockaddr*)&cliaddr, &clen);
int rflags = fcntl(result, F_GETFL, 0);
if (fcntl(result, F_SETFL, rflags | O_NONBLOCK)) {
fprintf(stderr, "tid=%llu Cannot set non blocking on socket\n", tid64);
close(result);
abort();
}
core_conn = result;
}
pollers[0].fd = core_conn;
pollers[0].events = POLLIN;
while(1) {
/*unsigned char byte_to_send = 0;
if (*my_port % 2 == 0) {
byte_to_send = 44;
} else {
byte_to_send = 22;
}
errno = 0;
write(pollers[0].fd, &byte_to_send, sizeof(byte_to_send));
if (errno == EAGAIN || errno == EWOULDBLOCK) {
fprintf(stderr, "tid=%llu, got blocked on a write!", tid64);
}*/
errno = 0;
int ready = poll(pollers, poller_count, rand_sleep);
if (ready == -1) {
fprintf(stderr, "Got bad return from poll, errno=%d\n", errno);
//pthread_exit(&RET_BAD);
} else if (ready == 0) {
printf("tid=%llu, No active request\n", tid64);
if (locked_val) {
// claimed val, check who it is.
if (locked_val == *my_port) {
// we currently have it, just release it.
printf("tid=%llu, going to write an unlock on %u\n", tid64, *my_port);
server_write_request(pollers[0].fd, MESSAGE_UNLOCK, *my_port);
} else {
// the other member has it, just wait to see when they release it
printf("tid=%llu, waiting for a release\n", tid64);
}
} else {
// unclaimed val and no active requests
printf("tid=%llu, locked_val is 0, going to write a lock on %u\n", tid64, *my_port);
server_write_request(pollers[0].fd, MESSAGE_LOCK, *my_port);
}
} else {
if (pollers[0].revents & POLLIN) {
unsigned char read_op = 0;
unsigned short read_val = 0;
errno = 0;
server_read_request(pollers[0].fd, &read_op, &read_val);
if (read_op == MESSAGE_LOCK) {
if (locked_val) {
if (locked_val == read_val) {
printf("tid=%llu, already set to requested value of %u\n", tid64, read_val);
} else {
printf("tid=%llu, lock is already acquired by value %u\n", tid64, locked_val);
}
server_write_request(pollers[0].fd, MESSAGE_LOCK_RESP, locked_val);
} else {
printf("tid=%llu, GOT LOCK setting locked val to %u\n", tid64, read_val);
locked_val = read_val;
server_write_request(pollers[0].fd, MESSAGE_LOCK_RESP, locked_val);
}
} else if (read_op == MESSAGE_UNLOCK) {
if (locked_val) {
if (locked_val == read_val) {
printf("tid=%llu, got unlock request for value %u\n", tid64, read_val);
locked_val = 0;
server_write_request(pollers[0].fd, MESSAGE_UNLOCK_RESP, read_val);
} else {
printf("tid=%llu, got unlock request but failed since not equal to lock=%u", tid64, locked_val);
server_write_request(pollers[0].fd, MESSAGE_UNLOCK_RESP, locked_val);
}
} else {
fprintf(stderr, "tid=%llu, ERROR, got unlock request despite being unlocked\n", tid64);
server_write_request(pollers[0].fd, MESSAGE_UNLOCK_RESP, locked_val);
}
} else if (read_op == MESSAGE_LOCK_RESP) {
if (read_val == *my_port) {
// lock on opposite node works, lets try to claim locally now
if (locked_val == 0) {
locked_val = *my_port;
printf("tid=%llu, claimed lock locally for val=%u after claiming remotely\n", tid64, locked_val);
} else {
printf("tid=%llu TIE detected, local val=%u, remote_val=%u\n", tid64, locked_val, read_val);
server_write_request(pollers[0].fd, MESSAGE_UNLOCK, *my_port);
}
} else {
printf("tid=%llu Failed to lock on the value %u\n", tid64, *my_port);
}
} else if (read_op == MESSAGE_UNLOCK_RESP) {
if (read_val) {
if (read_val == *my_port) {
printf("tid=%llu unlock for value %u worked\n", tid64, read_val);
assert(locked_val == *my_port);
locked_val = 0;
} else {
printf("tid=%llu unlock for value %u failed, remote is %u\n", tid64, *my_port, read_val);
}
} else {
fprintf(stderr, "tid=%llu unexpected failure on unlocking already released value\n", tid64);
}
} else {
fprintf(stderr, "tid=%llu unexpected op %u, errno=%d\n", tid64, read_op, errno);
abort();
}
}
}
}
if (*my_port % 2 != 0) {
close(my_server);
}
close(core_conn);
pthread_exit(&RET_GOOD);
}
/*
DEBUG tid=6976487 write op 4 and val 12001
tid=6976486 unlock for value 12001 worked
tid=6976486, No active request
tid=6976486, locked_val is 0, going to write a lock on 12001
DEBUG tid=6976486 write op 1 and val 12001
tid=6976487, GOT LOCK setting locked val to 12001
DEBUG tid=6976487 write op 3 and val 12001
tid=6976486, claimed lock locally for val=12001 after claiming remotely
tid=6976486, No active request
tid=6976486, going to write an unlock on 12001
DEBUG tid=6976486 write op 2 and val 12001
tid=6976487, got unlock request for value 12001
DEBUG tid=6976487 write op 4 and val 12001
tid=6976486 unlock for value 12001 worked
tid=6976486, No active request
tid=6976486, locked_val is 0, going to write a lock on 12001
DEBUG tid=6976486 write op 1 and val 12001
tid=6976487, GOT LOCK setting locked val to 12001
DEBUG tid=6976487 write op 3 and val 12001
tid=6976486, claimed lock locally for val=12001 after claiming remotely
tid=6976486, No active request
tid=6976486, going to write an unlock on 12001
DEBUG tid=6976486 write op 2 and val 12001
tid=6976487, got unlock request for value 12001
DEBUG tid=6976487 write op 4 and val 12001
tid=6976486 unlock for value 12001 worked
tid=6976486, No active request
tid=6976486, locked_val is 0, going to write a lock on 12001
DEBUG tid=6976486 write op 1 and val 12001
tid=6976487, GOT LOCK setting locked val to 12001
DEBUG tid=6976487 write op 3 and val 12001
tid=6976486, claimed lock locally for val=12001 after claiming remotely
*/
/*
 * Program entry point: starts one server_work thread per entry of port_nums
 * and waits for both of them.
 *
 * SIGPIPE is ignored so that a write to a closed peer surfaces as an error
 * from write() instead of terminating the process.
 *
 * Returns 0 when both threads were joined and neither reported a non-zero
 * exit status, EXIT_FAILURE otherwise. Aborts if a thread cannot be created.
 */
int main(int argc, char const *argv[])
{
    (void)argc;
    (void)argv;

    signal(SIGPIPE, SIG_IGN);

    pthread_t workers[2];
    int s1 = pthread_create(&workers[0], NULL, &server_work, port_nums);
    int s2 = pthread_create(&workers[1], NULL, &server_work, port_nums + 1);
    if (s1 || s2) {
        fprintf(stderr, "failed to create thread!!!\n");
        abort();
    }

    int failures = 0;
    for (size_t i = 0; i < sizeof(workers) / sizeof(workers[0]); i++) {
        /* pthread_join() stores a void*; receive it as such rather than casting an int**. */
        void* status = NULL;
        int rc = pthread_join(workers[i], &status);
        if (rc != 0) {
            fprintf(stderr, "failed to join thread, error=%d\n", rc);
            failures++;
        } else if (status == PTHREAD_CANCELED) {
            failures++;
        } else if (status != NULL && *(int*)status != 0) {
            /* Workers exit with a pointer to an int status; non-zero means failure. */
            failures++;
        }
    }
    return failures ? EXIT_FAILURE : 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment