Skip to content

Instantly share code, notes, and snippets.

@gamemann
Last active December 22, 2023 06:58
Show Gist options
  • Save gamemann/381bf5dbdc6edc62d158f3d0e49abef5 to your computer and use it in GitHub Desktop.
Save gamemann/381bf5dbdc6edc62d158f3d0e49abef5 to your computer and use it in GitHub Desktop.
A test server that receives Ethernet/IPv4 packets on a raw packet socket and uses epoll() to detect when the socket is ready to read. It also supports multi-threading via `pthreads`.
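/* Test program for forwarding packets to a destination. Currently it only captures and logs IPv4 TCP/UDP/ICMP packets; the destination IP and port are parsed but not yet used. */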
#include <errno.h>
#include <inttypes.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <sys/types.h>
#include <arpa/inet.h>
#include <error.h>
#include <net/ethernet.h>
#include <net/if.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <pthread.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <unistd.h>

/* Linux UAPI headers last: <linux/icmp.h> may pull in <linux/if.h>, which only
   coexists with glibc's <net/if.h> / <netinet/in.h> when those are included first. */
#include <linux/icmp.h>
#include <linux/if_packet.h>
#include <linux/tcp.h>
#include <linux/udp.h>
/* Size in bytes of each worker thread's receive buffer. */
#define MAX_PCKT_LENGTH 65507
/* Not referenced by the code in this file. */
#define MAX_EPOLL_EVENTS 64

/*
 * Run flag: cleared by the SIGINT handler and polled by main and by every
 * worker thread. _Atomic (lock-free on mainstream Linux targets) makes the
 * signal-handler write and cross-thread reads well defined, and forces main's
 * wait loop to actually re-read the flag instead of hoisting the load.
 */
_Atomic int cont = 1;

/* Number of logged packets; incremented concurrently by all worker threads. */
_Atomic int packetCount = 0;
/*
 * Configuration built by main from the command line and shared, by pointer,
 * with every worker thread.
 */
struct stuff
{
    char *lIP;        /* listen IP (argv[1]); not read by threadHndl */
    uint16_t lPort;   /* listen port (argv[2]); not read by threadHndl */
    char *dIP;        /* destination IP (argv[3]); not read by threadHndl */
    uint16_t dPort;   /* destination port (argv[4]); not read by threadHndl */
    int sockfd;       /* raw AF_PACKET socket; int like every file descriptor,
                         so values above 255 and -1 are representable */
    char *interface;  /* interface name (argv[5]) the socket is bound to */
};
/*
 * SIGINT handler: clears the global run flag `cont`, which main and the
 * worker threads poll in their while loops.
 */
void sigHndl(int tmp)
{
    (void) tmp; /* the signal number itself is not needed */
    cont = 0;
}
/* epoll_wait() timeout in milliseconds, so every worker periodically
   re-checks `cont` instead of blocking forever after shutdown is requested. */
enum { EPOLL_TIMEOUT_MS = 1000 };

/*
 * logPacket: decode one captured Ethernet/IPv4 frame and print a one-line
 * summary for TCP, UDP and ICMP packets.
 *
 * frame    - received bytes, starting at the Ethernet header.
 * len      - number of valid bytes in frame.
 * threadID - identifier of the calling thread; only used in the log line.
 *
 * Returns 1 if a summary line was printed, 0 if the frame was too short,
 * malformed, or carried another protocol.
 *
 * Headers are memcpy'd out of the byte buffer instead of being accessed via
 * casted pointers: the IP header starts at offset 14, so a cast would be a
 * misaligned access and would violate strict aliasing.
 */
static int logPacket(const char *frame, size_t len, unsigned long threadID)
{
    const size_t ipOff = sizeof(struct ethhdr);
    struct iphdr ip;

    if (len < ipOff + sizeof(ip))
    {
        return 0;
    }
    memcpy(&ip, frame + ipOff, sizeof(ip));

    /* ihl is the IP header length in 32-bit words, including options. */
    const size_t ipHdrLen = (size_t) ip.ihl * 4;
    if (ipHdrLen < sizeof(ip) || ipOff + ipHdrLen > len)
    {
        return 0;
    }
    const size_t l4Off = ipOff + ipHdrLen;

    /* End of the datagram per its own header, clamped to the captured bytes
       (frames can carry Ethernet padding past tot_len, or be truncated). */
    size_t ipEnd = ipOff + ntohs(ip.tot_len);
    if (ipEnd > len)
    {
        ipEnd = len;
    }

    /* inet_ntop writes into this thread's buffer; inet_ntoa's static buffer
       would be shared (and raced on) by all worker threads. */
    char addr[INET_ADDRSTRLEN];
    if (inet_ntop(AF_INET, &ip.daddr, addr, sizeof(addr)) == NULL)
    {
        return 0;
    }

    switch (ip.protocol)
    {
    case IPPROTO_TCP:
    {
        struct tcphdr tcp;
        if (len < l4Off + sizeof(tcp))
        {
            return 0;
        }
        memcpy(&tcp, frame + l4Off, sizeof(tcp));
        /* doff is the TCP header length in 32-bit words, including options. */
        const size_t dataOff = l4Off + (size_t) tcp.doff * 4;
        const size_t payload = ipEnd > dataOff ? ipEnd - dataOff : 0;
        fprintf(stdout, "Got TCP packet with %zu total length and %zu payload length. Destination address is %s. On thread %lu.\n", len, payload, addr, threadID);
        return 1;
    }
    case IPPROTO_UDP:
    {
        struct udphdr udp;
        if (len < l4Off + sizeof(udp))
        {
            return 0;
        }
        memcpy(&udp, frame + l4Off, sizeof(udp));
        /* udp.len is in network byte order and includes the UDP header. */
        const size_t udpLen = ntohs(udp.len);
        const size_t payload = udpLen > sizeof(udp) ? udpLen - sizeof(udp) : 0;
        fprintf(stdout, "Got UDP packet with %zu total length and %zu payload length. Destination address is %s. On thread %lu.\n", len, payload, addr, threadID);
        return 1;
    }
    case IPPROTO_ICMP:
        fprintf(stdout, "Got ICMP packet with %zu total length. Destination address is %s. On thread %lu.\n", len, addr, threadID);
        return 1;
    default:
        return 0;
    }
}

/*
 * threadHndl: worker thread start routine.
 *
 * data - pointer to the shared `struct stuff`; only its sockfd is used here.
 *
 * Waits (via a per-thread epoll instance) until the shared raw socket is
 * readable, receives one frame into a thread-local buffer, logs it with
 * logPacket() and increments the global packetCount for every logged packet.
 * Runs until `cont` is cleared or an unrecoverable epoll/recv error occurs.
 *
 * The socket is created by main and shared by all workers, so it is not
 * closed here (each worker closing it would be a double close); only this
 * thread's epoll descriptor is released.
 */
void* threadHndl(void * data)
{
    struct stuff *stuff = data;
    /* pthread_self() differs per thread (getpid() is identical for all of
       them). pthread_t is an integer or pointer on Linux libcs, so the cast
       yields a printable identifier. */
    unsigned long threadID = (unsigned long) pthread_self();
    char buffer[MAX_PCKT_LENGTH];

    int epfd = epoll_create1(0);
    if (epfd < 0)
    {
        fprintf(stderr, "EPoll() :: Error Creating - %s\n", strerror(errno));
        pthread_exit(NULL);
    }

    struct epoll_event event = { .events = EPOLLIN, .data = { .fd = stuff->sockfd } };
    if (epoll_ctl(epfd, EPOLL_CTL_ADD, stuff->sockfd, &event) < 0)
    {
        fprintf(stderr, "EPoll_ctl() :: Error - %s\n", strerror(errno));
        close(epfd);
        pthread_exit(NULL);
    }

    while (cont)
    {
        struct epoll_event ready;
        int nfds = epoll_wait(epfd, &ready, 1, EPOLL_TIMEOUT_MS);
        if (nfds < 0)
        {
            if (errno == EINTR)
            {
                continue; /* e.g. SIGINT delivered to this thread: re-check cont */
            }
            fprintf(stderr, "EPoll_wait() :: Error - %s\n", strerror(errno));
            break;
        }
        if (nfds == 0 || ready.data.fd != stuff->sockfd)
        {
            continue; /* timeout: loop back and re-check cont */
        }

        /* All workers watch the same socket, so another thread may already
           have consumed this frame; MSG_DONTWAIT keeps this thread from
           blocking in recv() where it could no longer observe cont. */
        ssize_t received = recv(stuff->sockfd, buffer, sizeof(buffer), MSG_DONTWAIT);
        if (received < 0)
        {
            if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
            {
                continue;
            }
            fprintf(stderr, "Recv() :: Error - %s\n", strerror(errno));
            break;
        }

        if (logPacket(buffer, (size_t) received, threadID))
        {
            packetCount++;
        }
    }

    close(epfd);
    pthread_exit(NULL);
}
/*
 * parseNumber: parse `text` as a base-10 unsigned integer in
 * [minValue, maxValue].
 *
 * On success stores the value in *out and returns 0. Returns -1, leaving *out
 * untouched, for empty input, trailing characters, overflow, or values outside
 * the range. Used instead of atoi(), which silently maps bad input to 0.
 */
static int parseNumber(const char *text, unsigned long minValue, unsigned long maxValue, unsigned long *out)
{
    char *end;
    errno = 0;
    unsigned long value = strtoul(text, &end, 10);
    if (end == text || *end != '\0' || errno == ERANGE || value < minValue || value > maxValue)
    {
        return -1;
    }
    *out = value;
    return 0;
}

/*
 * main: capture IPv4 frames on one interface and log them from worker threads.
 *
 * Usage: <Listen IP> <Listen Port> <Destination IP> <Destination Port> <Interface> [<Threads>]
 * Ports must be integers from 0 to 65535; Threads defaults to 1 and must be
 * from 1 to 255.
 *
 * Opens one AF_PACKET/SOCK_RAW socket bound to the interface, starts the
 * workers (all sharing that socket), waits until SIGINT clears `cont`, then
 * prints the packet count and the elapsed seconds. Exits with status 1 on
 * invalid arguments or setup failure.
 */
int main(int argc, char *argv[])
{
    if (argc < 6)
    {
        fprintf(stderr, "Usage: %s <Listen IP> <Listen Port> <Destination IP> <Destination Port> <Interface> [<Threads>]\n", argc > 0 ? argv[0] : "server");
        exit(1);
    }

    time_t startingTime = time(NULL);

    unsigned long lPort;
    unsigned long dPort;
    if (parseNumber(argv[2], 0, UINT16_MAX, &lPort) < 0 || parseNumber(argv[4], 0, UINT16_MAX, &dPort) < 0)
    {
        fprintf(stderr, "Invalid port: ports must be integers from 0 to %d.\n", (int) UINT16_MAX);
        exit(1);
    }

    /* Capped at UINT8_MAX to bound the number of workers; 0 would leave
       nothing reading the socket. */
    unsigned long threads = 1;
    if (argc > 6 && parseNumber(argv[6], 1, UINT8_MAX, &threads) < 0)
    {
        fprintf(stderr, "Invalid thread count: must be an integer from 1 to %d.\n", (int) UINT8_MAX);
        exit(1);
    }

    /* int, not uint8_t: socket() reports failure as -1. */
    int sockfd = socket(AF_PACKET, SOCK_RAW, htons(ETH_P_IP));
    if (sockfd < 0)
    {
        fprintf(stderr, "Socket() :: Error - %s\n", strerror(errno));
        exit(1);
    }

    if (setsockopt(sockfd, SOL_SOCKET, SO_BINDTODEVICE, argv[5], strlen(argv[5])) < 0)
    {
        fprintf(stderr, "SetSockOpt() :: Error %s\n", strerror(errno));
        close(sockfd);
        exit(1);
    }

    fprintf(stdout, "Binding to %s:%lu and redirecting to %s:%lu with %lu threads and on interface %s.\n\n", argv[1], lPort, argv[3], dPort, threads, argv[5]);

    /* Lives for the whole program (main never returns; it calls exit), so
       handing its address to the workers is safe. */
    struct stuff stuff;
    stuff.lIP = argv[1];
    stuff.lPort = (uint16_t) lPort;
    stuff.dIP = argv[3];
    stuff.dPort = (uint16_t) dPort;
    stuff.interface = argv[5];
    stuff.sockfd = sockfd;

    /* Install the handler before any worker exists, so an early SIGINT
       cannot hit the default (terminate) action. */
    if (signal(SIGINT, sigHndl) == SIG_ERR)
    {
        fprintf(stderr, "Signal() :: Error - %s\n", strerror(errno));
        close(sockfd);
        exit(1);
    }

    unsigned long started = 0;
    for (unsigned long i = 0; i < threads; i++)
    {
        fprintf(stdout, "Starting thread #%lu\n", i);
        pthread_t tid;
        int rc = pthread_create(&tid, NULL, threadHndl, &stuff);
        if (rc != 0)
        {
            /* pthread_create returns the error code instead of setting errno. */
            fprintf(stderr, "Pthread_create() :: Error - %s\n", strerror(rc));
            continue;
        }
        /* Workers are never joined; detach so their resources are released. */
        pthread_detach(tid);
        started++;
    }
    if (started == 0)
    {
        fprintf(stderr, "No worker threads could be started.\n");
        close(sockfd);
        exit(1);
    }

    /* Wait for SIGINT without spinning a CPU core. sleep() is an opaque call,
       so `cont` is re-read every iteration; the signal may be delivered to a
       worker, in which case shutdown is noticed within about a second. */
    while (cont)
    {
        sleep(1);
    }

    time_t stoppingTime = time(NULL);
    uint64_t timeE = (uint64_t) difftime(stoppingTime, startingTime);
    fprintf(stdout, "%d packets in %" PRIu64 " seconds\n\n\n", packetCount, timeE);
    exit(0);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment