Skip to content

Instantly share code, notes, and snippets.

@sbernard31
Last active March 11, 2023 18:04
Show Gist options
  • Star 13 You must be signed in to star a gist
  • Fork 4 You must be signed in to fork a gist
  • Save sbernard31/d4fee7518a1ff130452211c0d355b3f7 to your computer and use it in GitHub Desktop.
Save sbernard31/d4fee7518a1ff130452211c0d355b3f7 to your computer and use it in GitHub Desktop.
UDP load balancer prototype using bcc (XDP/BPF)
#define KBUILD_MODNAME "foo"
#include <uapi/linux/bpf.h>
#include <linux/bpf.h>
#include <linux/icmp.h>
#include <linux/if_ether.h>
#include <linux/if_vlan.h>
#include <linux/in.h>
#include <linux/ip.h>
#include <linux/tcp.h>
#include <linux/udp.h>
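/* Mask for the "more fragments" flag plus the 13-bit fragment offset (0x3FFF), stored byte-swapped so it can be ANDed directly with the network-order frag_off field on a little-endian host: 0x3FFF swapped is 0xFF3F == 65343 */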
#define IP_FRAGMENTED 65343
// MAC address: six raw bytes.
// NOTE(review): not referenced by the C code shown here; the structs below
// use `unsigned char[ETH_ALEN]` directly.
typedef unsigned char mac[6];
// Real Server structure (MAC address + IP address).
// Value type of the `realServers` map. The Python loader builds each entry
// positionally as Leaf(ip, mac), so the field order here must not change.
struct server {
// IPv4 address; copied verbatim into iph->daddr / iph->saddr, so it is stored
// in the same byte order as the packet header fields.
__be32 ipAddr;
// Ethernet address written into the frame's destination/source MAC.
unsigned char macAddr[ETH_ALEN];
};
// Packet summary pushed to user space through the `events` perf buffer to log
// load-balancing decisions. The Python loader decodes it with a ctypes
// structure declaring the same four fields in the same order.
struct packet {
// Destination MAC of the frame.
unsigned char dmac[ETH_ALEN];
// Source MAC of the frame.
unsigned char smac[ETH_ALEN];
// Destination IPv4 address, copied from iph->daddr.
__be32 daddr;
// Source IPv4 address, copied from iph->saddr.
__be32 saddr;
};
// Perf output channel named `events`; xdp_prog submits `struct packet` records
// to it and the loader script reads them via b["events"].
BPF_PERF_OUTPUT(events);
/*
 * Fold a 64-bit checksum accumulator into 16 bits and return its bitwise
 * complement.
 *
 * The bits above the low 16 are added back into the low 16 bits, at most four
 * times (enough to absorb every carry of a 64-bit value); the result is then
 * inverted and truncated to __u16 by the return type.
 */
__attribute__((__always_inline__))
static inline __u16 csum_fold_helper(__u64 csum) {
    int round;
#pragma unroll
    for (round = 0; round < 4; round++) {
        __u64 carry = csum >> 16;
        /* Nothing left above bit 15: the value is already folded. */
        if (carry != 0)
            csum = (csum & 0xffff) + carry;
    }
    return ~csum;
}
/*
 * Checksum a fixed-size IPv4 header without calling a BPF helper.
 *
 * iph:  start of the header; exactly sizeof(struct iphdr) bytes are read as
 *       16-bit words (IP options are not covered).
 * csum: in/out accumulator. Its incoming value is used as the seed; on return
 *       it holds the folded, complemented 16-bit checksum.
 */
__attribute__((__always_inline__))
static inline void ipv4_csum_inline(void *iph, __u64 *csum) {
    __u16 *words = (__u16 *)iph;
    __u64 sum = *csum;
#pragma clang loop unroll(full)
    for (int idx = 0; idx < sizeof(struct iphdr) / 2; idx++) {
        sum += words[idx];
    }
    *csum = csum_fold_helper(sum);
}
/*
 * Checksum an arbitrary buffer using the bpf_csum_diff() helper.
 *
 * data_start: first byte to checksum.
 * data_size:  number of bytes to checksum.
 *             NOTE(review): the helper is documented as requiring a multiple
 *             of 4 - confirm for every caller.
 * csum:       in/out accumulator. Its incoming value seeds the helper; on
 *             return it holds the folded, complemented 16-bit checksum.
 */
__attribute__((__always_inline__))
static inline void ipv4_csum(void *data_start, int data_size, __u64 *csum) {
    /* Empty "from" buffer: the helper only adds the bytes of data_start. */
    __u64 sum = bpf_csum_diff(0, 0, data_start, data_size, *csum);

    *csum = csum_fold_helper(sum);
}
/*
 * Compute an L4 checksum that includes the IPv4 pseudo-header.
 *
 * The sum covers, in order: source address, destination address, the protocol
 * number and the L4 length (each widened to 32 bits and byte-swapped), then
 * data_size bytes starting at data_start.
 *
 * data_start: first byte of the L4 header.
 * data_size:  number of L4 bytes to include; also used as the pseudo-header
 *             length value.
 * csum:       in/out accumulator. Its incoming value seeds the sum; on return
 *             it holds the folded, complemented 16-bit checksum.
 * iph:        IPv4 header supplying saddr, daddr and protocol.
 *
 * The swap is an unconditional __builtin_bswap32, i.e. the conversion to
 * network order assumes a little-endian host.
 */
__attribute__((__always_inline__))
static inline void ipv4_l4_csum(void *data_start, __u32 data_size,
                                __u64 *csum, struct iphdr *iph) {
    __u64 sum = *csum;
    __u32 word = 0;

    /* Pseudo-header: addresses. */
    sum = bpf_csum_diff(0, 0, &iph->saddr, sizeof(__be32), sum);
    sum = bpf_csum_diff(0, 0, &iph->daddr, sizeof(__be32), sum);

    /* Pseudo-header: protocol number. */
    word = __builtin_bswap32((__u32)iph->protocol);
    sum = bpf_csum_diff(0, 0, &word, sizeof(__u32), sum);

    /* Pseudo-header: L4 length. */
    word = __builtin_bswap32(data_size);
    sum = bpf_csum_diff(0, 0, &word, sizeof(__u32), sum);

    /* L4 header (and payload, if data_size covers it). */
    sum = bpf_csum_diff(0, 0, data_start, data_size, sum);

    *csum = csum_fold_helper(sum);
}
// Set of UDP ports to redirect. The key is the port exactly as it appears in
// the UDP header (the loader inserts htons(port)); xdp_prog only tests whether
// a key is present, the int value itself is not read.
BPF_HASH(ports, __be16, int, 10); // TODO: how do we handle the maximum number of ports we support?
// Real servers, keyed by an integer index starting at 0. xdp_prog currently
// only ever looks up index 0.
BPF_HASH(realServers, int, struct server, 10); // TODO: how do we handle the maximum number of real servers?
// XDP entry point of the UDP load balancer.
//
// The virtual IP is available as the 'VIP' constant and the context struct
// name as 'CTXTYPE'; the loader script supplies both on the compiler command
// line (-DVIP=<address as an integer>, -DCTXTYPE=xdp_md).
//
// Return values:
//   XDP_DROP - a header does not fit in the frame, or IHL < 5.
//   XDP_PASS - anything this program does not rewrite: non-IPv4, IPv4 with
//              options, fragments, non-UDP, address is not the VIP, port not
//              in `ports`, or no real server at index 0.
//   XDP_TX   - the frame was rewritten (MAC + IP address) and both checksums
//              were recomputed.
int xdp_prog(struct CTXTYPE *ctx) {
void *data_end = (void *)(long)ctx->data_end;
void *data = (void *)(long)ctx->data;
// https://git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git/tree/include/uapi/linux/if_ether.h
struct ethhdr * eth = data;
// Bounds check: the whole Ethernet header must lie inside the frame.
if (eth + 1 > data_end)
return XDP_DROP;
// Handle only IP packets (v4?)
if (eth->h_proto != bpf_htons(ETH_P_IP)){
return XDP_PASS;
}
// https://git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git/tree/include/uapi/linux/ip.h
struct iphdr *iph;
// The IP header starts right after the Ethernet header. The assignment mixes
// pointer types without a cast; the loader compiles with -w, which hides the
// resulting warning.
iph = eth + 1;
if (iph + 1 > data_end)
return XDP_DROP;
// Minimum valid header length value is 5.
// see (https://tools.ietf.org/html/rfc791#section-3.1)
if (iph->ihl < 5)
return XDP_DROP;
// IP header size is variable because of options field.
// see (https://tools.ietf.org/html/rfc791#section-3.1)
//if ((void *) iph + iph->ihl * 4 > data_end)
// return XDP_DROP;
// TODO support IP header with variable size
// Headers carrying options are left untouched, so the code below can assume
// the L4 header sits exactly sizeof(struct iphdr) bytes after iph.
if (iph->ihl != 5)
return XDP_PASS;
// Do not support fragmented packets as L4 headers may be missing
if (iph->frag_off & IP_FRAGMENTED)
return XDP_PASS; // TODO should we support it ?
// We only handle UDP traffic
if (iph->protocol != IPPROTO_UDP) {
return XDP_PASS;
}
// https://git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git/tree/include/uapi/linux/udp.h
struct udphdr *udp;
//udp = (void *) iph + iph->ihl * 4;
udp = iph + 1;
if (udp + 1 > data_end)
return XDP_DROP;
//__u16 udp_len = bpf_ntohs(udp->len);
// The length is hard-coded to the 8-byte UDP header instead of being read
// from the packet, so the `< 8` and `> 512` tests below can never fire.
// NOTE(review): because udp_len is also what gets passed to ipv4_l4_csum(),
// the UDP checksum written further down covers only the header and uses 8 as
// the pseudo-header length; it looks wrong for any datagram with a payload -
// confirm.
__u16 udp_len = 8;
if (udp_len < 8)
return XDP_DROP;
if (udp_len > 512) // TODO use a more appropriate max value
return XDP_DROP;
if ((void *) udp + udp_len > data_end)
return XDP_DROP;
// Is it ingress traffic ? destination IP == VIP
if (iph->daddr == VIP) {
// Only ports registered in the `ports` map are load balanced.
if (!ports.lookup(&(udp->dest))) {
return XDP_PASS;
} else {
// Log packet before
struct packet pkt = {};
// Copies the first sizeof(pkt) bytes of the frame. The leading 12 bytes are
// the destination and source MAC, which line up with pkt.dmac / pkt.smac; the
// bytes copied beyond them are overwritten by the two assignments below.
memcpy(&pkt, data, sizeof(pkt)); // crappy
pkt.daddr = iph->daddr;
pkt.saddr = iph->saddr;
events.perf_submit(ctx,&pkt,sizeof(pkt));
// handle ingress traffic
// TODO support several real server
// Always uses the real server stored at index 0.
int i = 0;
struct server * server = realServers.lookup(&i);
if (server == NULL) {
return XDP_PASS;
}
// Redirect the frame: destination MAC and destination IP become the real
// server's.
memcpy(eth->h_dest, server->macAddr, 6);
iph->daddr = server->ipAddr;
}
} else
// Is it egress traffic ? source IP == VIP
if (iph->saddr == VIP) {
if (!ports.lookup(&(udp->source))) {
return XDP_PASS;
} else {
// Log packet before
struct packet pkt = {};
// Same copy as on the ingress path: MACs come from the frame, addresses are
// filled in explicitly.
memcpy(&pkt, data, sizeof(pkt)); // crappy
pkt.daddr = iph->daddr;
pkt.saddr = iph->saddr;
events.perf_submit(ctx,&pkt,sizeof(pkt));
// handle egress traffic
// TODO support several real server
int i = 0;
struct server * server = realServers.lookup(&i);
if (server == NULL) {
return XDP_PASS;
}
// NOTE(review): this replaces a source address that already equals the VIP
// with the real server's MAC/IP. A reply path would normally be expected to
// map real server -> VIP instead - confirm the intended direction.
memcpy(eth->h_source, server->macAddr, 6);
iph->saddr = server->ipAddr;
}
} else {
// Neither addressed to nor sent from the VIP: not our traffic.
return XDP_PASS;
}
// Update IP checksum
// TODO support IP header with variable size
// The checksum field must be zero while the header is being summed.
iph->check = 0;
__u64 cs = 0 ;
ipv4_csum(iph, sizeof (*iph), &cs);
iph->check = cs;
// Update UDP checksum
udp->check = 0;
cs = 0;
ipv4_l4_csum(udp, udp_len, &cs, iph) ;
udp->check = cs;
// Log packet after
struct packet pkt = {};
memcpy(&pkt, data, sizeof(pkt)); // crappy
pkt.daddr = iph->daddr;
pkt.saddr = iph->saddr;
events.perf_submit(ctx,&pkt,sizeof(pkt));
// Hand the rewritten frame back for transmission.
return XDP_TX;
}
#!/usr/bin/python
from __future__ import print_function
from bcc import BPF
import ctypes as ct
import ipaddress
import socket
import argparse
import binascii
import struct
import re
# Utils
def ip_strton(ip_address):
    """Convert a textual IP address to a 32-bit integer in network byte order.

    The text is parsed with ``ipaddress.ip_address`` (which raises
    ``ValueError`` on malformed input) and the resulting host-order integer is
    passed through ``socket.htonl``.
    """
    host_order = int(ipaddress.ip_address(ip_address))
    return socket.htonl(host_order)
def ip_ntostr(ip_address):
    """Convert a network-byte-order integer back to an IP address object.

    Accepts either a plain ``int`` or a ``ctypes.c_uint`` (as found in
    structures read from the perf buffer). Returns the object produced by
    ``ipaddress.ip_address``; it renders as dotted-quad text when formatted.
    """
    raw = ip_address.value if isinstance(ip_address, ct.c_uint) else ip_address
    return ipaddress.ip_address(socket.ntohl(raw))
def mac_strtob(mac_address):
    """Convert a textual MAC address (``AA:BB:CC:DD:EE:FF``) to 6 raw bytes.

    Colons are stripped and the remainder is decoded as hexadecimal.

    Raises:
        binascii.Error: if the text is not valid hexadecimal.
        TypeError: if the decoded value is not exactly 6 bytes long.
    """
    raw = binascii.unhexlify(mac_address.replace(':', ''))
    # Compare by value: `is not 6` tests object identity, which only happens
    # to work through CPython's small-integer cache and triggers a
    # SyntaxWarning on Python 3.8+.
    if len(raw) != 6:
        raise TypeError("mac address must be a 6 bytes arrays")
    return raw
def mac_btostr(mac_address):
    """Format a 6-byte MAC address as lowercase colon-separated hex.

    ``mac_address`` may be anything ``bytes()`` accepts (bytes, a list of
    ints, a ctypes byte array). Only the first six bytes are rendered.
    """
    hex_digits = bytes(mac_address).hex()
    octets = []
    for start in range(0, 12, 2):
        octets.append(hex_digits[start:start + 2])
    return ':'.join(octets)
def ip_mac_tostr(mac_address, ip_address):
    """Render a MAC/IP pair as ``<mac>/<ip>``.

    ``mac_address`` is formatted with ``mac_btostr`` and ``ip_address`` (a
    network-byte-order integer or ``c_uint``) with ``ip_ntostr``.
    """
    mac_text = mac_btostr(mac_address)
    ip_text = ip_ntostr(ip_address)
    return "{}/{}".format(mac_text, ip_text)
# Custom argument parser
def mac_ip_parser(s, pat=re.compile("^(.+?)/(.+)$")):
    """argparse ``type=`` callback parsing a ``MAC_addr/IP_addr`` argument.

    The text is split at the first ``/``; the left part is converted with
    ``mac_strtob`` and the right part with ``ip_strton``.

    Returns:
        dict with keys ``"ip"`` (network-byte-order int) and ``"mac"``
        (6 raw bytes).

    Raises:
        argparse.ArgumentTypeError: if the text has no ``/`` separator or
            either half fails to convert.
    """
    match = pat.match(s)
    if match is None:
        raise argparse.ArgumentTypeError("Invalid address '{}': format is 'MAC_addr/IP_addr' (e.g. 5E:FF:56:A2:AF:15/10.40.0.1)".format(s))

    mac_text, ip_text = match.groups()

    try:
        mac = mac_strtob(mac_text)
    except Exception as err:
        raise argparse.ArgumentTypeError("Invalid MAC address '{}' : {}".format(mac_text, str(err)))

    try:
        ip = ip_strton(ip_text)
    except Exception as err:
        raise argparse.ArgumentTypeError("Invalid IP address '{}' : {}".format(ip_text, str(err)))

    return {"ip": ip, "mac": mac}
# Command-line interface: one positional interface name plus the required
# virtual IP, real server and port options.
parser = argparse.ArgumentParser()
parser.add_argument(
    "ifnet",
    help="network interface to load balance (e.g. eth0)",
)
parser.add_argument(
    "-vip", "--virtual_ip",
    required=True,
    help="<Required> The virtual IP of this loadbalancer",
)
# nargs=1: exactly one real server is accepted, delivered as a one-item list.
parser.add_argument(
    "-rs", "--real_server",
    type=mac_ip_parser,
    nargs=1,
    required=True,
    help="<Required> Real server addresse(s) e.g. 5E:FF:56:A2:AF:15/10.40.0.1",
)
parser.add_argument(
    "-p", "--port",
    type=int,
    nargs='+',
    required=True,
    help="<Required> UDP port(s) to load balance",
)
parser.add_argument(
    "-d", "--debug",
    type=int,
    choices=[0, 1, 2, 3, 4],
    default=0,
    help="Use to set bpf verbosity (0 is minimal)",
)
args = parser.parse_args()

# Configuration derived from the arguments
ifnet = args.ifnet                # network interface to attach the xdp program to
vip = ip_strton(args.virtual_ip)  # virtual ip of the load balancer (network byte order)
real_servers = args.real_server   # list of {"ip": ..., "mac": ...} dicts
ports = args.port                 # UDP ports to load balance
debug = args.debug                # bpf verbosity

# The template has two placeholders; the third format() argument has no
# matching placeholder and is ignored by str.format.
print(
    "\nLoad balancing UDP traffic over {} interface for port(s) {} from :".format(
        ifnet, ports, ip_ntostr(vip)
    )
)
for real_server in real_servers:
    print("VIP:{} <=======> Real Server:{}".format(
        ip_ntostr(vip),
        ip_mac_tostr(real_server["mac"], real_server["ip"]),
    ))
# ctypes view of the records read from the "events" perf buffer. The fields
# are declared in the same order and with the same sizes as the C
# `struct packet` that the BPF program submits.
class Data(ct.Structure):
    _fields_ = [
        ("dmac", ct.c_ubyte * 6),   # destination MAC
        ("smac", ct.c_ubyte * 6),   # source MAC
        ("daddr", ct.c_uint),       # destination IPv4 address
        ("saddr", ct.c_uint),       # source IPv4 address
    ]
# Compile the BPF program and attach it to the interface.
# VIP and CTXTYPE are injected as preprocessor definitions; -w silences
# compiler warnings.
bpf_cflags = ["-w", "-DVIP={}".format(vip), "-DCTXTYPE=xdp_md"]
b = BPF(src_file="test.c", debug=debug, cflags=bpf_cflags)
fn = b.load_func("xdp_prog", BPF.XDP)
b.attach_xdp(ifnet, fn)

# Set Configurations
## Ports: keys are stored in network byte order, matching the raw UDP header
## field the BPF program looks up.
ports_map = b["ports"]
for port in ports:
    port_key = ports_map.Key(socket.htons(port))
    ports_map[port_key] = ports_map.Leaf(True)

## Real servers: entry N of the command line goes to map index N.
real_servers_map = b.get_table("realServers")
for i, real_server in enumerate(real_servers):
    mac_array = (ct.c_ubyte * 6).from_buffer_copy(real_server['mac'])
    real_servers_map[real_servers_map.Key(i)] = real_servers_map.Leaf(real_server['ip'], mac_array)
def print_event(cpu, data, size):
    """Perf-buffer callback: print one logged packet.

    ``data`` is reinterpreted as a ``Data`` structure; ``cpu`` and ``size``
    are part of the callback signature and are not used.
    """
    event = ct.cast(data, ct.POINTER(Data)).contents
    source = ip_mac_tostr(event.smac, event.saddr)
    dest = ip_mac_tostr(event.dmac, event.daddr)
    print("source {} --> dest {}".format(source, dest))
# Read the perf buffer until interrupted, then detach the XDP program.
# The try/finally guarantees the program is removed from the interface even
# when an unexpected exception escapes the loop; previously only a clean
# Ctrl-C reached the detach call, leaving the program attached otherwise.
try:
    b["events"].open_perf_buffer(print_event)
    while True:
        try:
            b.perf_buffer_poll()
        except ValueError:
            # Ignore and keep polling.
            continue
        except KeyboardInterrupt:
            break
finally:
    # Detach bpf program
    b.remove_xdp(ifnet)
@sbernard31
Copy link
Author

usage: test.py [-h] -vip VIRTUAL_IP -rs REAL_SERVER -p PORT [PORT ...]
               [-d {0,1,2,3,4}]
               ifnet

positional arguments:
  ifnet                 network interface to load balance (e.g. eth0)

optional arguments:
  -h, --help            show this help message and exit
  -vip VIRTUAL_IP, --virtual_ip VIRTUAL_IP
                        <Required> The virtual IP of this loadbalancer
  -rs REAL_SERVER, --real_server REAL_SERVER
                        <Required> Real server addresse(s) e.g.
                        5E:FF:56:A2:AF:15/10.40.0.1
  -p PORT [PORT ...], --port PORT [PORT ...]
                        <Required> UDP port(s) to load balance
  -d {0,1,2,3,4}, --debug {0,1,2,3,4}
                        Use to set bpf verbosity (0 is minimal)

Eg : sudo python3 test.py lo -vip 10.41.44.13 -rs 00:00:00:00:00:00/127.0.0.1 -p 5683 5684

@sbernard31
Copy link
Author

See a more advanced version of the code at : https://github.com/AirVantage/sbulb

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment