Skip to content

Instantly share code, notes, and snippets.

@cypres
Last active March 6, 2020 17:03
Show Gist options
  • Save cypres/b83303b4988a4afb2a75 to your computer and use it in GitHub Desktop.
Save cypres/b83303b4988a4afb2a75 to your computer and use it in GitHub Desktop.
IPv4 in CIDR
#ifdef __FreeBSD__
#include <sys/socket.h>
#endif
#include <arpa/inet.h>
#include <netinet/in.h>
#include <cstdlib>
#include <cstring>
#include <cerrno>
#include <cassert>
#include <iostream>
#include <cstdint>
/// Converts a CIDR prefix length into an IPv4 netmask.
///
/// @param cidr  Prefix length; must be within 0..32.
/// @param addr  Receives the netmask (stored the way in_addr expects it,
///              i.e. most significant octet first in memory).
/// @return 0 on success; -1 with errno set to EINVAL when cidr is out of
///         range, in which case *addr is left untouched.
int inet_cidrtoaddr(int cidr, struct in_addr *addr) {
    const bool in_range = (cidr >= 0) && (cidr <= 32);
    if (!in_range) {
        errno = EINVAL;
        return -1;
    }

    // Build the mask in host order: the top `cidr` bits set, the rest clear.
    // A /0 prefix is handled separately because shifting a 32-bit value by
    // 32 positions is undefined.
    uint32_t host_order_mask = 0;
    if (cidr != 0) {
        host_order_mask = 0xFFFFFFFFu << (32 - cidr);
    }

    // htonl lays the octets out so that the leading 0xFF bytes come first in
    // memory, regardless of the machine's endianness.
    addr->s_addr = htonl(host_order_mask);
    return 0;
}
/// Tests whether an IPv4 address lies inside a CIDR network.
///
/// Both addresses are compared against a mask that is converted with htonl,
/// so they are expected in network byte order (as produced by inet_aton).
///
/// @param addr  Address to test.
/// @param net   Network address. Bits beyond the prefix are ignored, so
///              192.168.0.1/24 behaves the same as 192.168.0.0/24.
/// @param bits  Prefix length, 0..32.
/// @return true when the leading `bits` bits of addr and net are equal;
///         false when they differ or when bits is not a valid IPv4 prefix
///         length (> 32).
bool cidr_match(const in_addr &addr, const in_addr &net, uint8_t bits) {
    if (bits == 0) {
        // A /0 prefix matches every address. It needs its own branch because
        // C99 6.5.7 (3): u32 << 32 is undefined behaviour.
        return true;
    }
    if (bits > 32) {
        // Invalid prefix length: 32 - bits would be a negative shift count,
        // which is undefined behaviour as well. Treat it as "no match".
        return false;
    }
    const uint32_t mask = htonl(0xFFFFFFFFu << (32 - bits));
    return ((addr.s_addr ^ net.s_addr) & mask) == 0;
}
/// Tests whether an IPv6 address lies inside a CIDR network.
///
/// The comparison walks the addresses through the POSIX-standard `s6_addr`
/// byte array, so it needs no platform-specific union members and no
/// byte-order conversion.
///
/// @param address  Address to test.
/// @param network  Network address; bits beyond the prefix are ignored.
/// @param bits     Prefix length, 0..128.
/// @return true when the leading `bits` bits of both addresses are equal
///         (a /0 prefix therefore matches everything); false when they
///         differ or when bits is not a valid IPv6 prefix length (> 128).
bool cidr6_match(const in6_addr &address, const in6_addr &network, uint8_t bits) {
    if (bits > 128) {
        // An in6_addr holds only 16 bytes; a longer prefix would read past
        // the end of both addresses.
        return false;
    }

    const uint8_t *a = address.s6_addr;
    const uint8_t *n = network.s6_addr;

    const std::size_t whole_bytes = bits >> 3;  // bytes fully covered by the prefix
    const unsigned tail_bits = bits & 0x07u;    // prefix bits in the next, partial byte

    if (whole_bytes != 0 && std::memcmp(a, n, whole_bytes) != 0) {
        return false;
    }

    if (tail_bits != 0) {
        // tail_bits != 0 implies bits < 128, so whole_bytes <= 15 and the
        // index below stays inside the 16-byte array.
        const uint8_t mask = static_cast<uint8_t>(0xFFu << (8 - tail_bits));
        if (((a[whole_bytes] ^ n[whole_bytes]) & mask) != 0) {
            return false;
        }
    }
    return true;
}
// Self-checking demo: exercises inet_cidrtoaddr, cidr_match and cidr6_match
// and prints each result before asserting on it. std::endl is used so every
// line is flushed before a failing assert aborts the process.
int main() {
    in_addr ip, net, netmask;

    // Check if 192.168.0.17 is present in 192.168.0.0/24
    inet_aton("192.168.0.17", &ip);
    inet_aton("192.168.0.0", &net);
    inet_cidrtoaddr(24, &netmask);
    std::cout << "Got netmask: " << inet_ntoa(netmask) << std::endl;
    bool a = ((ip.s_addr & netmask.s_addr) == (net.s_addr & netmask.s_addr));
    std::cout << "Test A: " << a << std::endl;
    assert(a);

    // Another way
    bool b = cidr_match(ip, net, 24);
    std::cout << "Test B: " << b << std::endl;
    assert(b);

    // Check 0.0.0.0/0
    in_addr all;
    inet_aton("0.0.0.0", &all);
    bool c = cidr_match(ip, all, 0);
    std::cout << "Test C: " << c << std::endl;
    assert(c);

    // Check it does in fact return false if outside the range
    in_addr outside;
    inet_aton("192.168.1.0", &outside);
    bool d = cidr_match(outside, net, 24);
    std::cout << "Test D: " << d << std::endl;
    assert(!d);

    // Test specifying a wrong network address:
    // 192.168.0.1/24 should behave like 192.168.0.0/24, because the host
    // bits of the network address are masked away in both approaches.
    in_addr wrong_net;
    inet_aton("192.168.0.1", &wrong_net);
    bool e = ((ip.s_addr & netmask.s_addr) == (wrong_net.s_addr & netmask.s_addr));
    std::cout << "Test E: " << e << std::endl;
    assert(e);
    bool f = cidr_match(ip, wrong_net, 24);
    std::cout << "Test F: " << f << std::endl;
    assert(f);

    // Check that matching is in fact classless (not restricted to class C / 24):
    // 192.168.0.17 is outside 192.168.0.0/28 but inside 192.168.0.16/28
    in_addr classless_net;
    inet_aton("192.168.0.16", &classless_net);
    bool g = cidr_match(ip, classless_net, 28);
    bool h = cidr_match(ip, net, 28);
    std::cout << "Test G: " << g << std::endl;
    std::cout << "Test H: " << h << std::endl;
    assert(g);
    assert(!h);

    // Throw in a little IPv6 too
    in6_addr ip6, net6, net6_48;
    memset(&net6, 0, sizeof(net6));
    memset(&net6_48, 0, sizeof(net6_48));

    // Check if 2001:db8::ff00:42:8329 is present in 2001:db8/32
    // Beware inet_net_pton is very picky, it's 2001:db8/32 not 2001:db8::/32
    // However 2001:db8::/48 is perfectly valid (also known as 2001:db8:0/48)
    // NOTE(review): inet_net_pton is a BSD extension; some libc
    // implementations reportedly accept only AF_INET - verify on the target
    // platform.
    //
    // The parse is done outside assert() so that it still runs when NDEBUG
    // is defined, and it is compared against 1 because inet_pton reports
    // failure with 0 or -1 (the latter would pass a plain truthiness check).
    if (inet_pton(AF_INET6, "2001:db8::ff00:42:8329", &ip6) != 1) {
        std::cerr << "inet_pton failed" << std::endl;
        return EXIT_FAILURE;
    }
    int bits = inet_net_pton(AF_INET6, "2001:db8/32", &net6, sizeof(net6));
    assert((bits != -1)); // assert that inet_net_pton understood us
    bool i = cidr6_match(ip6, net6, bits);
    std::cout << "Test I: " << i << std::endl;
    assert(i);

    // Check against the smaller /48 too
    int bits_48 = inet_net_pton(AF_INET6, "2001:db8::/48", &net6_48, sizeof(net6_48));
    assert((bits_48 == 48));
    bool j = cidr6_match(ip6, net6_48, bits_48);
    std::cout << "Test J: " << j << std::endl;
    assert(j);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment