Optimized fork of yl2chen/cidranger (originally forked from phemmer/trie.go).
// Package iptrie is a fork of github.com/yl2chen/cidranger. This fork massively strips down and refactors the code for
// increased performance, resulting in 20x faster load time and 1.5x faster lookups.
package iptrie
import (
	"fmt"
	"math/bits"
	"net/netip"
	"strings"
	"unsafe"
)
// Trie is an IP radix trie implementation, similar to what is described
// at https://vincent.bernat.im/en/blog/2017-ipv4-route-lookup-linux
//
// CIDR blocks are stored using a prefix tree structure where each node's
// network is a prefix of its children's networks, and the path from the root
// node represents the current CIDR block.
//
// Path compression collapses a chain of nodes that each have only one child
// into a single node, decreasing the number of lookups necessary during
// containment tests.
type Trie struct {
	parent   *Trie
	children [2]*Trie
	network  netip.Prefix
	value    any
}
// NewTrie creates a new Trie.
func NewTrie() *Trie {
	return &Trie{
		network: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
	}
}
func newSubTree(network netip.Prefix, value any) *Trie {
	return &Trie{
		network: network,
		value:   value,
	}
}
// Insert inserts the given value, keyed by the given network, into the trie.
func (p *Trie) Insert(network netip.Prefix, value any) {
	network = normalizePrefix(network)
	p.insert(network, value)
}
// Remove removes the value identified by the given network from the trie and
// returns it, or nil if no such entry exists.
func (p *Trie) Remove(network netip.Prefix) any {
	network = normalizePrefix(network)
	return p.remove(network)
}
// Find returns the value stored on the shortest (least specific) prefix that
// contains the given address, or nil if no stored prefix contains it.
func (p *Trie) Find(ip netip.Addr) any {
	ip = normalizeAddr(ip)
	return p.find(ip)
}
// ContainingNetworks returns the list of networks that contain the given ip,
// in ascending prefix-length order (least specific first).
func (p *Trie) ContainingNetworks(ip netip.Addr) []netip.Prefix {
	ip = normalizeAddr(ip)
	return p.containingNetworks(ip)
}
// CoveredNetworks returns the list of networks the given network covers.
// That is, the networks that are completely subsumed by the specified network.
func (p *Trie) CoveredNetworks(network netip.Prefix) []netip.Prefix {
	network = normalizePrefix(network)
	return p.coveredNetworks(network)
}
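// exampleTrieUsage is an illustrative sketch, not part of the original gist:
// it shows how the exported methods above are expected to fit together, using
// arbitrary example prefixes and addresses. Note that IPv4 inputs are
// normalized to their 4-in-6 form internally, so returned prefixes look like
// ::ffff:10.0.0.0/104.
func exampleTrieUsage() {
	trie := NewTrie()
	trie.Insert(netip.MustParsePrefix("10.0.0.0/8"), "corp")
	trie.Insert(netip.MustParsePrefix("192.168.0.0/16"), "lab")
	// A containment lookup: 10.1.2.3 falls inside 10.0.0.0/8.
	fmt.Println(trie.Find(netip.MustParseAddr("10.1.2.3"))) // corp
	// No stored prefix contains 8.8.8.8, so Find returns nil.
	fmt.Println(trie.Find(netip.MustParseAddr("8.8.8.8"))) // <nil>
	// All stored networks containing the address (in normalized form).
	fmt.Println(trie.ContainingNetworks(netip.MustParseAddr("10.1.2.3")))
	// All stored networks subsumed by 10.0.0.0/8.
	fmt.Println(trie.CoveredNetworks(netip.MustParsePrefix("10.0.0.0/8")))
}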
// Network returns the prefix represented by this node.
func (p *Trie) Network() netip.Prefix {
	return p.network
}
// String returns a string representation of the trie, mainly for visualization
// and debugging.
func (p *Trie) String() string {
	children := []string{}
	padding := strings.Repeat("| ", p.level()+1)
	for bit, child := range p.children {
		if child == nil {
			continue
		}
		childStr := fmt.Sprintf("\n%s%d--> %s", padding, bit, child.String())
		children = append(children, childStr)
	}
	return fmt.Sprintf("%s (has_entry:%t)%s", p.network,
		p.value != nil, strings.Join(children, ""))
}
func (p *Trie) find(number netip.Addr) any {
	if !netContains(p.network, number) {
		return nil
	}
	if p.value != nil {
		return p.value
	}
	if p.network.Bits() == 128 {
		return nil
	}
	bit := p.discriminatorBitFromIP(number)
	child := p.children[bit]
	if child != nil {
		return child.find(number)
	}
	return nil
}
func (p *Trie) containingNetworks(addr netip.Addr) []netip.Prefix {
	var results []netip.Prefix
	if !p.network.Contains(addr) {
		return results
	}
	if p.value != nil {
		results = []netip.Prefix{p.network}
	}
	if p.network.Bits() == 128 {
		return results
	}
	bit := p.discriminatorBitFromIP(addr)
	child := p.children[bit]
	if child != nil {
		ranges := child.containingNetworks(addr)
		if len(ranges) > 0 {
			if len(results) > 0 {
				results = append(results, ranges...)
			} else {
				results = ranges
			}
		}
	}
	return results
}
func (p *Trie) coveredNetworks(network netip.Prefix) []netip.Prefix {
	var results []netip.Prefix
	if network.Bits() <= p.network.Bits() && network.Contains(p.network.Addr()) {
		for entry := range p.walkDepth() {
			results = append(results, entry)
		}
	} else if p.network.Bits() < 128 {
		bit := p.discriminatorBitFromIP(network.Addr())
		child := p.children[bit]
		if child != nil {
			return child.coveredNetworks(network)
		}
	}
	return results
}
// netContains is an unsafe, but faster, version of netip.Prefix.Contains.
func netContains(pfx netip.Prefix, ip netip.Addr) bool {
	pfxAddr := addr128(pfx.Addr())
	ipAddr := addr128(ip)
	return ipAddr.xor(pfxAddr).and(mask6(pfx.Bits())).isZero()
}
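// netContainsExample is an illustrative sketch, not part of the original gist:
// for inputs that have been normalized by normalizePrefix/normalizeAddr,
// netContains is expected to agree with the standard netip.Prefix.Contains.
func netContainsExample() bool {
	pfx := normalizePrefix(netip.MustParsePrefix("10.0.0.0/8"))
	ip := normalizeAddr(netip.MustParseAddr("10.1.2.3"))
	return netContains(pfx, ip) == pfx.Contains(ip) // true: both report containment
}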
// netDivergence returns the largest prefix shared by the two provided
// prefixes, i.e. the longest prefix that still contains both networks.
func netDivergence(net1 netip.Prefix, net2 netip.Prefix) netip.Prefix {
	if net1.Bits() > net2.Bits() {
		net1, net2 = net2, net1
	}
	if netContains(net1, net2.Addr()) {
		return net1
	}
	diff := addr128(net1.Addr()).xor(addr128(net2.Addr()))
	var bit int
	if diff.hi != 0 {
		bit = bits.LeadingZeros64(diff.hi)
	} else {
		bit = bits.LeadingZeros64(diff.lo) + 64
	}
	if bit > net1.Bits() {
		bit = net1.Bits()
	}
	pfx, _ := net1.Addr().Prefix(bit)
	return pfx
}
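// netDivergenceExample is an illustrative sketch, not part of the original
// gist: for two sibling networks the shared prefix is their longest common
// ancestor, e.g. 10.1.0.0/16 and 10.2.0.0/16 diverge at 10.0.0.0/14
// (::ffff:10.0.0.0/110 once normalized to 4-in-6 form).
func netDivergenceExample() netip.Prefix {
	a := normalizePrefix(netip.MustParsePrefix("10.1.0.0/16"))
	b := normalizePrefix(netip.MustParsePrefix("10.2.0.0/16"))
	return netDivergence(a, b) // ::ffff:10.0.0.0/110, i.e. 10.0.0.0/14
}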
func (p *Trie) insert(network netip.Prefix, value any) *Trie {
	if p.network == network {
		p.value = value
		return p
	}
	bit := p.discriminatorBitFromIP(network.Addr())
	existingChild := p.children[bit]
	// No existing child, insert new leaf trie.
	if existingChild == nil {
		pNew := newSubTree(network, value)
		p.appendTrie(bit, pNew)
		return pNew
	}
	// Check whether it is necessary to insert an additional path prefix between the current trie and the
	// existing child, in case the inserted network diverges on its path to the existing child.
	netdiv := netDivergence(existingChild.network, network)
	if netdiv != existingChild.network {
		pathPrefix := newSubTree(netdiv, nil)
		p.insertPrefix(bit, pathPrefix, existingChild)
		// Continue the insert below the new intermediate node.
		existingChild = pathPrefix
	}
	return existingChild.insert(network, value)
}
func (p *Trie) appendTrie(bit uint8, prefix *Trie) {
	p.children[bit] = prefix
	prefix.parent = p
}
func (p *Trie) insertPrefix(bit uint8, pathPrefix, child *Trie) {
	// Set parent/child relationship between current trie and inserted pathPrefix
	p.children[bit] = pathPrefix
	pathPrefix.parent = p
	// Set parent/child relationship between inserted pathPrefix and original child
	pathPrefixBit := pathPrefix.discriminatorBitFromIP(child.network.Addr())
	pathPrefix.children[pathPrefixBit] = child
	child.parent = pathPrefix
}
func (p *Trie) remove(network netip.Prefix) any {
	if p.value != nil && p.network == network {
		entry := p.value
		p.value = nil
		p.compressPathIfPossible()
		return entry
	}
	if p.network.Bits() == 128 {
		return nil
	}
	bit := p.discriminatorBitFromIP(network.Addr())
	child := p.children[bit]
	if child != nil {
		return child.remove(network)
	}
	return nil
}
func (p *Trie) qualifiesForPathCompression() bool {
	// The current prefix trie can be path compressed if it meets all of the following:
	// 1. it records no CIDR entry
	// 2. it has a single child or no child
	// 3. it is not the root trie
	return p.value == nil && p.childrenCount() <= 1 && p.parent != nil
}
func (p *Trie) compressPathIfPossible() {
	if !p.qualifiesForPathCompression() {
		// Does not qualify to be compressed
		return
	}
	// Find lone child.
	var loneChild *Trie
	for _, child := range p.children {
		if child != nil {
			loneChild = child
			break
		}
	}
	// Find the root of the current single-child lineage.
	parent := p.parent
	for ; parent.qualifiesForPathCompression(); parent = parent.parent {
	}
	parentBit := parent.discriminatorBitFromIP(p.network.Addr())
	parent.children[parentBit] = loneChild
	if loneChild != nil {
		// Keep the lone child's parent pointer consistent with its new location.
		loneChild.parent = parent
	}
	// Attempt to further apply path compression at the current lineage parent, in case the current
	// lineage was compressed into the parent.
	parent.compressPathIfPossible()
}
func (p *Trie) childrenCount() int {
	count := 0
	for _, child := range p.children {
		if child != nil {
			count++
		}
	}
	return count
}
func (p *Trie) discriminatorBitFromIP(addr netip.Addr) uint8 {
	// The child to descend into is selected by the bit of addr at position
	// p.network.Bits(); callers ensure this position is never 128 or more.
	pos := p.network.Bits()
	a128 := addr128(addr)
	if pos < 64 {
		return uint8(a128.hi >> (63 - pos) & 1)
	}
	return uint8(a128.lo >> (63 - (pos - 64)) & 1)
}
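// discriminatorBitExample is an illustrative sketch, not part of the original
// gist: a node for 10.0.0.0/8 (::ffff:10.0.0.0/104 once normalized) branches
// on bit 104 of the address, which is the most significant bit of the second
// IPv4 octet.
func discriminatorBitExample() uint8 {
	node := newSubTree(normalizePrefix(netip.MustParsePrefix("10.0.0.0/8")), nil)
	return node.discriminatorBitFromIP(normalizeAddr(netip.MustParseAddr("10.128.0.0"))) // 1
}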
func (p *Trie) level() int {
	if p.parent == nil {
		return 0
	}
	return p.parent.level() + 1
}
// walkDepth walks the trie in depth order, for unit testing.
func (p *Trie) walkDepth() <-chan netip.Prefix {
	entries := make(chan netip.Prefix)
	go func() {
		if p.value != nil {
			entries <- p.network
		}
		childEntriesList := []<-chan netip.Prefix{}
		for _, trie := range p.children {
			if trie == nil {
				continue
			}
			childEntriesList = append(childEntriesList, trie.walkDepth())
		}
		for _, childEntries := range childEntriesList {
			for entry := range childEntries {
				entries <- entry
			}
		}
		close(entries)
	}()
	return entries
}
// TrieLoader can be used to improve the performance of bulk inserts to a Trie. It caches the node of the
// last insert in the tree, using it as the starting point when searching for the location of the next
// insert. This is highly beneficial when the addresses are pre-sorted.
type TrieLoader struct {
	trie       *Trie
	lastInsert *Trie
}
func NewTrieLoader(trie *Trie) *TrieLoader {
	return &TrieLoader{
		trie:       trie,
		lastInsert: trie,
	}
}
func (ptl *TrieLoader) Insert(pfx netip.Prefix, v any) {
	pfx = normalizePrefix(pfx)
	// Locate the first bit at which the new prefix diverges from the last
	// inserted network, clamped to both prefix lengths.
	diff := addr128(ptl.lastInsert.network.Addr()).xor(addr128(pfx.Addr()))
	var pos int
	if diff.hi != 0 {
		pos = bits.LeadingZeros64(diff.hi)
	} else {
		pos = bits.LeadingZeros64(diff.lo) + 64
	}
	if pos > pfx.Bits() {
		pos = pfx.Bits()
	}
	if pos > ptl.lastInsert.network.Bits() {
		pos = ptl.lastInsert.network.Bits()
	}
	// Walk back up from the last insert to the deepest ancestor that can still
	// contain the new prefix, and insert from there instead of from the root.
	parent := ptl.lastInsert
	for parent.network.Bits() > pos {
		parent = parent.parent
	}
	ptl.lastInsert = parent.insert(pfx, v)
}
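// exampleTrieLoader is an illustrative sketch, not part of the original gist:
// it shows bulk loading with TrieLoader. The prefixes are arbitrary; sorting
// them first keeps consecutive inserts close together in the trie, which is
// the case TrieLoader is optimized for.
func exampleTrieLoader(prefixes []netip.Prefix) *Trie {
	trie := NewTrie()
	loader := NewTrieLoader(trie)
	for i, pfx := range prefixes {
		loader.Insert(pfx, i)
	}
	return trie
}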
// normalizeAddr converts IPv4 addresses to their 4-in-6 mapped form so the
// trie only ever operates on 128-bit keys.
func normalizeAddr(addr netip.Addr) netip.Addr {
	if addr.Is4() {
		return netip.AddrFrom16(addr.As16())
	}
	return addr
}
// normalizePrefix converts IPv4 prefixes to their 4-in-6 mapped form (adding
// 96 to the prefix length, e.g. 10.0.0.0/8 becomes ::ffff:10.0.0.0/104) and
// masks off any host bits.
func normalizePrefix(pfx netip.Prefix) netip.Prefix {
	if pfx.Addr().Is4() {
		pfx = netip.PrefixFrom(netip.AddrFrom16(pfx.Addr().As16()), pfx.Bits()+96)
	}
	return pfx.Masked()
}
// addr128 reinterprets the 16-byte address stored inside a netip.Addr as a
// uint128 (see the layout check in init below).
func addr128(addr netip.Addr) uint128 {
	return *(*uint128)(unsafe.Pointer(&addr))
}
func init() {
	// Accessing the underlying data of a `netip.Addr` relies upon the data being
	// in a known format, which is not guaranteed to be stable. So this init()
	// function is to detect if it ever changes.
	ip := netip.AddrFrom16([16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15})
	i128 := addr128(ip)
	if i128.hi != 0x0001020304050607 || i128.lo != 0x08090a0b0c0d0e0f {
		panic("netip.Addr format mismatch")
	}
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package iptrie
import "math/bits"
// uint128 represents a uint128 using two uint64s.
//
// When the methods below mention a bit number, bit 0 is the most
// significant bit (in hi) and bit 127 is the lowest (lo&1).
type uint128 struct {
	hi uint64
	lo uint64
}
// mask6 returns a uint128 bitmask with the topmost n bits of a
// 128-bit number.
func mask6(n int) uint128 {
	return uint128{^(^uint64(0) >> n), ^uint64(0) << (128 - n)}
}
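// mask6Example is an illustrative sketch, not part of the original gist:
// mask6(104) has the top 104 bits set (all of hi plus the top 40 bits of lo),
// matching the mask used for a normalized 10.0.0.0/8 (::ffff:10.0.0.0/104).
func mask6Example() bool {
	m := mask6(104)
	return m.hi == ^uint64(0) && m.lo == 0xffffffff_ff000000 // true
}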
// isZero reports whether u == 0.
//
// It's faster than u == (uint128{}) because the compiler (as of Go
// 1.15/1.16b1) doesn't do this trick and instead inserts a branch in
// its eq alg's generated code.
func (u uint128) isZero() bool { return u.hi|u.lo == 0 }
// and returns the bitwise AND of u and m (u&m).
func (u uint128) and(m uint128) uint128 {
	return uint128{u.hi & m.hi, u.lo & m.lo}
}
// xor returns the bitwise XOR of u and m (u^m).
func (u uint128) xor(m uint128) uint128 {
	return uint128{u.hi ^ m.hi, u.lo ^ m.lo}
}
// or returns the bitwise OR of u and m (u|m).
func (u uint128) or(m uint128) uint128 {
	return uint128{u.hi | m.hi, u.lo | m.lo}
}
// not returns the bitwise NOT of u.
func (u uint128) not() uint128 {
	return uint128{^u.hi, ^u.lo}
}
// subOne returns u - 1.
func (u uint128) subOne() uint128 {
	lo, borrow := bits.Sub64(u.lo, 1, 0)
	return uint128{u.hi - borrow, lo}
}
// addOne returns u + 1.
func (u uint128) addOne() uint128 {
	lo, carry := bits.Add64(u.lo, 1, 0)
	return uint128{u.hi + carry, lo}
}
// halves returns the two uint64 halves of the uint128.
//
// Logically, think of it as returning two uint64s.
// It only returns pointers for inlining reasons on 32-bit platforms.
func (u *uint128) halves() [2]*uint64 {
	return [2]*uint64{&u.hi, &u.lo}
}
// bitsSetFrom returns a copy of u with the given bit
// and all subsequent ones set.
func (u uint128) bitsSetFrom(bit uint8) uint128 {
	return u.or(mask6(int(bit)).not())
}
// bitsClearedFrom returns a copy of u with the given bit
// and all subsequent ones cleared.
func (u uint128) bitsClearedFrom(bit uint8) uint128 {
	return u.and(mask6(int(bit)))
}