Skip to content

Instantly share code, notes, and snippets.

@mentha
Created May 3, 2022 13:30
Show Gist options
  • Save mentha/9d346a9b5967f745297861b6480fcc30 to your computer and use it in GitHub Desktop.
Save mentha/9d346a9b5967f745297861b6480fcc30 to your computer and use it in GitHub Desktop.
compress list of cidr
/* Compress a CIDR list by aggregating adjacent subnets. */
#define _POSIX_C_SOURCE 200809L
#include <arpa/inet.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
/* Global command-line options, filled in by main(). */
struct {
	int max_length;       /* longest accepted prefix length, -1 = no limit */
	int quiet;            /* non-zero: suppress diagnostics on stderr */
	int truncate;         /* non-zero: clear host bits instead of discarding */
	const char *progname; /* prefix printed in front of every diagnostic */
	int ipv6;             /* non-zero: print results in IPv6 notation */
} opt;

/* Print a printf-style diagnostic to stderr, prefixed with the program name,
 * unless -q was given.  The do/while(0) wrapper makes each use a single
 * statement, so it is safe in an unbraced if/else. */
#define ERRORF(fmt, ...) do { \
	if (!opt.quiet) \
		fprintf(stderr, "%s: " fmt "\n", opt.progname, __VA_ARGS__); \
} while (0)

/* Print a fixed diagnostic message (see ERRORF). */
#define ERROR(msg) ERRORF("%s", msg)
/* Emit the usage synopsis followed by a one-line description of the tool.
 * Both lines go through ERRORF, so they land on stderr and are silenced
 * by -q. */
static void print_help(void)
{
	ERRORF("usage: %s [-m max-length] [-q] [-t]", opt.progname);
	ERRORF("%s", "Compress CIDR list by aggregrating adjacent subnets.");
}
/* malloc() that never returns NULL: on allocation failure it reports
 * "out of memory" and terminates the process with abort(). */
static void *xmalloc(size_t s)
{
	void *block = malloc(s);
	if (block != NULL)
		return block;
	ERROR("out of memory");
	abort();
}
/* One network on the working list.  IPv4 networks are stored as
 * IPv4-mapped IPv6 prefixes (::ffff:a.b.c.d) with 96 added to their length,
 * so every entry is handled as a 128-bit prefix. */
struct cidr {
	struct cidr *next;  /* next entry of the singly linked list */
	uint8_t prefix[16]; /* network address in network byte order */
	uint8_t mask;       /* prefix length in bits, 0..128 */
};

/* Fill buf with a netmask whose first len bits are set and the rest clear.
 * len is clamped to 0..128 so no caller can index outside buf or shift by a
 * negative amount (previously len == -1 was undefined behaviour). */
static void len2mask(uint8_t buf[16], int len)
{
	if (len < 0)
		len = 0;
	else if (len > 128)
		len = 128;
	for (int i = 0; i < 16; i++) {
		int bits = len - i * 8; /* mask bits falling into byte i */
		if (bits >= 8)
			buf[i] = 0xffU;
		else if (bits <= 0)
			buf[i] = 0x00U;
		else
			buf[i] = (uint8_t)(0xff00U >> bits);
	}
}

/* Total order used for sorting: by prefix bytes, then by prefix length with
 * the shorter (larger) network first.  Placing a covering network before a
 * same-prefix subnet lets opt_cidrs() drop the subnet via cidr_contain(). */
static int cidr_cmp(const struct cidr *a, const struct cidr *b)
{
	int r = memcmp(a->prefix, b->prefix, sizeof a->prefix);
	if (r != 0)
		return r;
	return (a->mask > b->mask) - (a->mask < b->mask);
}

/* Clear every host bit of c->prefix beyond c->mask.
 * Returns 1 if any bit had to be cleared, 0 if the prefix was already clean. */
static int cidr_truncate(struct cidr *c)
{
	uint8_t mask[16];
	int changed = 0;
	len2mask(mask, c->mask);
	for (int i = 0; i < 16; i++) {
		if (c->prefix[i] & ~mask[i]) {
			changed = 1;
			c->prefix[i] &= mask[i];
		}
	}
	return changed;
}

/* Return 1 if network a contains network b (a equals b or is a supernet of
 * it), else 0.  a->prefix is assumed to be already truncated to a->mask. */
static int cidr_contain(const struct cidr *a, const struct cidr *b)
{
	/* A longer prefix is a smaller network and cannot contain a shorter one. */
	if (a->mask > b->mask)
		return 0;
	/* b lies inside a iff b's address, cut down to a's length, equals a. */
	struct cidr t = { .next = NULL, .mask = a->mask };
	memcpy(t.prefix, b->prefix, sizeof t.prefix);
	cidr_truncate(&t);
	return cidr_cmp(a, &t) == 0;
}

/* Return 1 if a and b have the same length and agree on all but the last
 * prefix bit, i.e. together they form the network one bit shorter.  Identical
 * entries also return 1, so callers test cidr_contain() first.  A /0 has no
 * shorter parent and is never adjacent to anything. */
static int cidr_adjacent(const struct cidr *a, const struct cidr *b)
{
	if (a->mask != b->mask || a->mask == 0)
		return 0;
	uint8_t mask[16];
	len2mask(mask, a->mask - 1);
	for (int i = 0; i < 16; i++)
		if ((a->prefix[i] ^ b->prefix[i]) & mask[i])
			return 0;
	return 1;
}
/* Parse one input line of the form "address/length" into a newly allocated
 * cidr node.  IPv4 addresses are stored IPv4-mapped with 96 added to the
 * length.  buf is modified in place (the '/' is overwritten with NUL).
 *
 * Returns NULL, after printing a diagnostic, when the line is rejected:
 *   - there is no '/', or the length is missing, non-numeric or followed by
 *     anything other than blanks;
 *   - the length exceeds 128 (IPv6), 32 (IPv4) or the -m limit;
 *   - the address does not parse;
 *   - host bits are set and -t was not given (with -t they are cleared).
 * A valid IPv6 line switches the whole output to IPv6 notation. */
static struct cidr *cidr_parse(char *buf)
{
	char *slash = strchr(buf, '/');
	if (slash == NULL) {
		ERROR("discarding this line since it is not cidr");
		return NULL;
	}
	*slash = 0;

	const char *len_str = slash + 1;
	char *end;
	long m = strtol(len_str, &end, 10);
	/* Only the trailing newline (or other blanks) may follow the digits;
	 * atoi() used to turn "", "x" or "24junk" into a usable length. */
	while (*end == ' ' || *end == '\t' || *end == '\r' || *end == '\n')
		end++;
	if (*len_str < '0' || *len_str > '9' || *end != '\0') {
		ERROR("discarding this line since its prefix length is not a number");
		return NULL;
	}
	if (m < 0 || m > 128 || (opt.max_length >= 0 && m > opt.max_length)) {
		ERROR("discarding this line since its prefix length is out of range");
		return NULL;
	}

	uint8_t p[16];
	if (strchr(buf, ':')) {
		if (inet_pton(AF_INET6, buf, p) != 1) {
			ERROR("invalid IPv6 prefix discarded");
			return NULL;
		}
		if (!opt.ipv6) {
			opt.ipv6 = 1;
			ERROR("using IPv6 on output");
		}
	} else {
		/* An IPv4 length above 32 would become a mask above 128 after the
		 * +96 shift below. */
		if (m > 32) {
			ERROR("discarding this line since its prefix length is out of range");
			return NULL;
		}
		memcpy(p, "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff", 12);
		if (inet_pton(AF_INET, buf, &p[12]) != 1) {
			ERROR("invalid IPv4 prefix discarded");
			return NULL;
		}
		m += 96;
	}

	struct cidr *r = xmalloc(sizeof(*r));
	r->next = NULL;
	memcpy(r->prefix, p, sizeof r->prefix);
	r->mask = (uint8_t)m;
	if (cidr_truncate(r) && !opt.truncate) {
		ERROR("discarding this line since its prefix and length does not match");
		free(r);
		return NULL;
	}
	return r;
}
/* Write every entry of the list starting at c to stdout, one
 * "address/length" per line.  In IPv4 mode only the last four bytes of the
 * IPv4-mapped prefix are shown and 96 is subtracted from the length. */
static void cidr_print(const struct cidr *c)
{
	char text[200];
	const struct cidr *cur;
	for (cur = c; cur != NULL; cur = cur->next) {
		if (!opt.ipv6) {
			inet_ntop(AF_INET, cur->prefix + 12, text, sizeof text);
			printf("%s/%d\n", text, cur->mask - 96);
			continue;
		}
		inet_ntop(AF_INET6, cur->prefix, text, sizeof text);
		printf("%s/%d\n", text, cur->mask);
	}
}
/* Free every node of the linked list that starts at c (c may be NULL). */
static void cidr_freeall(struct cidr *c)
{
	struct cidr *victim;
	while ((victim = c) != NULL) {
		c = victim->next;
		free(victim);
	}
}
/* Read one fgets() chunk of at most size-1 bytes from stdin into buf.
 * Returns 0 at end of input or on a read error.  Otherwise returns 1 and
 * sets *eol to whether the chunk contains the line's terminating '\n'.
 * buf is zeroed first so memchr() finds the newline even when an embedded
 * NUL byte would hide it from strlen(); this also avoids the old
 * strlen(line) - 1 underflow on a chunk that starts with NUL. */
static int read_chunk(char *buf, int size, int *eol)
{
	memset(buf, 0, (size_t)size);
	if (fgets(buf, size, stdin) == NULL)
		return 0;
	*eol = memchr(buf, '\n', (size_t)size) != NULL;
	return 1;
}

/* Read one line from stdin and parse it with cidr_parse().  A line longer
 * than the buffer is parsed from its first chunk; the remainder of that line
 * is consumed and ignored so the next call starts on a fresh line.
 * Returns the parsed node, or NULL at end of input, on a read error, or when
 * the line was rejected. */
static struct cidr *read_line(void)
{
	char line[200];
	int eol = 0;
	if (!read_chunk(line, (int)sizeof line, &eol))
		return NULL;
	struct cidr *r = cidr_parse(line);
	while (!eol) {
		if (!read_chunk(line, (int)sizeof line, &eol))
			break;
	}
	return r;
}
/* A run of already-sorted cidr nodes.  While input is read, runs are kept on
 * a stack linked through `next`; list..tail is the run itself and count its
 * length. */
struct sorted_cidr {
	struct sorted_cidr *next; /* run pushed before this one, or NULL */
	struct cidr *list;        /* first node of the run */
	struct cidr *tail;        /* last node of the run */
	size_t count;             /* number of nodes in the run */
};
/* Merge sorted runs on the run stack whose top is s (s->next is the run
 * pushed before it).
 *
 * Merging repeats while there is a run below s and either mergeall is set or
 * s holds at least as many nodes as the run below it; with mergeall set the
 * whole stack collapses into s.  Every merge keeps s->list, s->tail and
 * s->count up to date, unlinks the absorbed run from the stack and frees its
 * sorted_cidr header (the cidr nodes themselves are relinked, not copied).
 * Ordering is by cidr_cmp(). */
static void sorted_cidr_collapse(struct sorted_cidr *s, int mergeall)
{
while (s->next != NULL && (mergeall || s->count >= s->next->count)) {
/* Fast path: s's last node sorts no later than the first node below, so
 * the run below is simply appended after s's tail. */
if (cidr_cmp(s->tail, s->next->list) <= 0) {
struct sorted_cidr *t = s->next;
s->tail->next = t->list;
s->tail = t->tail;
s->count += t->count;
s->next = t->next;
free(t);
} else if (cidr_cmp(s->list, s->next->tail) >= 0) {
/* Fast path: the run below sorts entirely before s, so it is
 * prepended in front of s's first node; s->tail is unchanged. */
struct sorted_cidr *t = s->next;
t->tail->next = s->list;
s->list = t->list;
s->count += t->count;
s->next = t->next;
free(t);
} else {
/* General case: walk s (a) and the run below (b) together, relinking
 * each b node in front of the first a node that sorts after it.
 * pa is the last node already placed in the merged list. */
struct cidr *a = s->list, *b = s->next->list, *pa = NULL;
while (a != NULL && b != NULL) {
if (cidr_cmp(a, b) > 0) {
struct cidr *t = b->next;
b->next = a;
if (pa == NULL)
s->list = b;
else
pa->next = b;
pa = b;
b = t;
} else {
pa = a;
a = a->next;
}
}
/* If s ran out first, the remaining b nodes are still linked to each
 * other in order and are appended after pa (pa is non-NULL here since
 * every run holds at least one node).  If b ran out first, the
 * remaining a nodes are already in place. */
if (a == NULL)
pa->next = b;
/* Recompute the tail by walking from the last placed node to the end. */
s->tail = pa;
while (s->tail->next != NULL)
s->tail = s->tail->next;
s->count += s->next->count;
struct sorted_cidr *t = s->next;
s->next = t->next;
free(t);
}
}
}
/* Read CIDR lines from stdin until end of input or a read error and return
 * them as one list sorted by cidr_cmp(), or NULL if no valid line was read.
 *
 * Sorting is a merge sort done while reading: s is a stack of sorted runs
 * (newest on top).  A new node either starts a fresh one-node run or joins a
 * lone-node run on top, after which sorted_cidr_collapse() merges runs while
 * the top one is at least as long as the one below it.  At the end all
 * remaining runs are merged into one. */
static struct cidr *read_and_sort(void)
{
	struct sorted_cidr *s = NULL;
	/* Also stop on ferror(): after a read error fgets() may keep returning
	 * NULL without the end-of-file indicator ever being set, which used to
	 * make this loop spin forever. */
	while (!feof(stdin) && !ferror(stdin)) {
		struct cidr *n = read_line();
		if (n == NULL)
			continue;
		if (s == NULL || s->count != 1) {
			/* Push a new single-node run. */
			struct sorted_cidr *ns = xmalloc(sizeof(*ns));
			ns->next = s;
			ns->count = 1;
			ns->list = n;
			ns->tail = n;
			s = ns;
		} else {
			/* The top run holds one node: insert n to form a sorted pair. */
			s->count++;
			if (cidr_cmp(n, s->list) < 0) {
				n->next = s->list;
				s->list = n;
			} else {
				s->list->next = n;
				s->tail = n;
			}
			sorted_cidr_collapse(s, 0);
		}
	}
	if (s == NULL)
		return NULL;
	sorted_cidr_collapse(s, 1);
	struct cidr *r = s->list;
	free(s);
	return r;
}
/* Make one pass over the sorted list c.  Every entry covered by the entry
 * before it is dropped, and every pair of equal-length adjacent networks is
 * fused into their one-bit-shorter parent (kept in the first node).
 * Returns 1 if the list changed, 0 otherwise; aggregate() repeats passes
 * until nothing changes. */
static int opt_cidrs(struct cidr *c)
{
	int changed = 0;
	for (struct cidr *cur = c; cur != NULL && cur->next != NULL; cur = cur->next) {
		struct cidr *nx = cur->next;
		int drop = cidr_contain(cur, nx);
		if (!drop && cidr_adjacent(cur, nx)) {
			cur->mask--; /* cur now covers both halves */
			drop = 1;
		}
		if (!drop)
			continue;
		cur->next = nx->next;
		free(nx);
		changed = 1;
	}
	return changed;
}
/* Read, sort and aggregate the CIDR list from stdin and print the result on
 * stdout.  Returns the process exit status: 0 on success, 1 if reading stdin
 * failed (nothing is printed then) or writing stdout failed. */
static int aggregate(void)
{
	struct cidr *nets = read_and_sort();
	if (ferror(stdin)) {
		ERROR("error reading input");
		cidr_freeall(nets);
		return 1;
	}
	/* Each pass may enable further merges; repeat until a pass is a no-op. */
	while (opt_cidrs(nets))
		;
	cidr_print(nets);
	cidr_freeall(nets);
	/* printf() errors are sticky; flush so late write failures surface too. */
	if (fflush(stdout) != 0 || ferror(stdout)) {
		ERROR("error writing output");
		return 1;
	}
	return 0;
}
/* Entry point: parse options, then aggregate stdin to stdout.
 *   -m N  discard lines whose prefix length exceeds N (0..128)
 *   -q    suppress diagnostics
 *   -t    clear host bits instead of discarding mismatching lines
 * -o, -p and -v are recognised but rejected as unsupported. */
int main(int argc, char **argv)
{
	opt.max_length = -1;
	opt.quiet = 0;
	opt.truncate = 0;
	opt.progname = argc == 0 ? "aggregate" : argv[0];
	opt.ipv6 = 0;
	int c;
	while ((c = getopt(argc, argv, "m:o:p:qtv")) != -1) {
		switch (c) {
		case 'm': {
			char *end;
			long len = strtol(optarg, &end, 10);
			/* Reject empty or partly numeric arguments that atoi() would
			 * silently map to 0 or truncate. */
			if (end == optarg || *end != '\0' || len < 0 || len > 128) {
				ERRORF("can't set maximum prefix length to %s", optarg);
				return 1;
			}
			opt.max_length = (int)len;
			break;
		}
		case 'q':
			opt.quiet = 1;
			break;
		case 't':
			opt.truncate = 1;
			break;
		case 'o':
		case 'p':
		case 'v':
			ERROR("option not supported");
			return 1;
		default:
			print_help();
			return 1;
		}
	}
	/* A limit above 32 is taken to mean IPv6 input: print in IPv6 notation. */
	if (opt.max_length > 32) {
		opt.ipv6 = 1;
		ERROR("using IPv6 on output");
	}
	return aggregate();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment