Skip to content

Instantly share code, notes, and snippets.

@davidbalbert
Last active April 13, 2023 17:26
Show Gist options
  • Save davidbalbert/2d326ae2827fc948a47b5eaa15c1620d to your computer and use it in GitHub Desktop.
Save davidbalbert/2d326ae2827fc948a47b5eaa15c1620d to your computer and use it in GitHub Desktop.
// commonPrefixLen returns the number of leading bytes shared by a and b.
// The comparison is byte-wise, so the result may fall inside a multi-byte
// UTF-8 sequence.
func commonPrefixLen(a, b string) int {
	limit := len(a)
	if len(b) < limit {
		limit = len(b)
	}
	for n := 0; n < limit; n++ {
		if a[n] != b[n] {
			return n
		}
	}
	// One string is a prefix of the other (or both are equal).
	return limit
}
// edge is a labelled link from a parent node to a child node in the tree.
type edge struct {
// label is the key fragment that must be matched in order to follow this edge.
label string
// node is the child reached by following this edge.
node *node
}
// node is a vertex of the radix tree. The key a node represents is the
// concatenation of the edge labels on the path from the root to that node.
type node struct {
// hasValue reports whether a value has been stored for the key ending at this node.
hasValue bool
// value is the stored value; load returns it together with hasValue.
value any
// edgeIndex holds the first byte of each outgoing edge's label. store keeps it
// in ascending order so that it can be binary-searched.
edgeIndex []byte
// edges holds the outgoing edges, parallel to edgeIndex: edges[i] is the edge
// whose label starts with edgeIndex[i].
edges []*edge
}
// store associates value with key in the tree rooted at n, overwriting any
// value previously stored under the same key. The empty key is stored on n
// itself.
//
// The tree is descended one edge at a time. If no outgoing edge starts with
// the next byte of the key, a new leaf edge carrying the rest of the key is
// inserted at its sorted position. If an edge shares only part of its label
// with the key, the edge is split at the point of divergence first.
func (n *node) store(key string, value any) {
	cur := n
	for len(key) > 0 {
		first := key[0]
		pos := sort.Search(len(cur.edgeIndex), func(j int) bool {
			return cur.edgeIndex[j] >= first
		})

		if pos == len(cur.edgeIndex) || cur.edgeIndex[pos] != first {
			// No edge starts with this byte: add a leaf holding the remainder
			// of the key. Growing by one and shifting the tail right covers
			// both insertion in the middle and appending at the end.
			leaf := &edge{label: key, node: &node{hasValue: true, value: value}}

			cur.edgeIndex = append(cur.edgeIndex, 0)
			copy(cur.edgeIndex[pos+1:], cur.edgeIndex[pos:])
			cur.edgeIndex[pos] = first

			cur.edges = append(cur.edges, nil)
			copy(cur.edges[pos+1:], cur.edges[pos:])
			cur.edges[pos] = leaf
			return
		}

		branch := cur.edges[pos]
		shared := commonPrefixLen(branch.label, key)
		if shared < len(branch.label) {
			// The key diverges from (or ends inside) the label: split the edge
			// so that the shared part leads to a new intermediate node which
			// keeps the old subtree under the unshared remainder.
			tail := &edge{label: branch.label[shared:], node: branch.node}
			mid := &node{
				edgeIndex: []byte{tail.label[0]},
				edges:     []*edge{tail},
			}
			branch.label = branch.label[:shared]
			branch.node = mid
		}

		// The whole (possibly shortened) label is now a prefix of key.
		key = key[shared:]
		cur = branch.node
	}

	// The key has been fully consumed; cur is the node that represents it.
	cur.hasValue = true
	cur.value = value
}
// load looks up key in the tree rooted at n. It returns the value stored
// under key and true if the key is present, or false as the second result
// if it is not. The empty key refers to n itself.
//
// A key is present only if it ends exactly on a node that holds a value; a
// key that ends in the middle of an edge label is reported as missing.
func (n *node) load(key string) (any, bool) {
	for len(key) > 0 {
		i := sort.Search(len(n.edgeIndex), func(i int) bool {
			return n.edgeIndex[i] >= key[0]
		})
		if i == len(n.edgeIndex) || n.edgeIndex[i] != key[0] {
			// No outgoing edge starts with the next byte of the key.
			return nil, false
		}

		e := n.edges[i]
		if !strings.HasPrefix(key, e.label) {
			// The key diverges from the label or ends inside it.
			return nil, false
		}

		// The whole label matched: consume it and descend.
		key = key[len(e.label):]
		n = e.node
	}
	return n.value, n.hasValue
}
// walkPartialTokensFunc is the callback invoked by walkPartialTokens for each
// matching key, with the full key and the value stored under it. Returning a
// non-nil error stops the walk; that error is returned by walkPartialTokens.
type walkPartialTokensFunc func(key string, value any) error
// walkPartialTokens tokenizes keys in the tree using sep as a separator, and calls fn for each
// key that matches the query. The query is tokenized using the same separator, and each token
// in the query must be a prefix of a corresponding token in the key. The number of tokens in
// each matched key must match the number of tokens in the query.
//
// E.g. if sep is ' ', then the query "fo ba" will match the keys "foo bar" and "foo baz", but not
// "foo bar baz". As a special case, a query of "" (or one consisting only of separators) will
// match the key "", and nothing else, for any value of sep.
//
// Matching keys that share a parent node are reported in ascending order of the first byte of
// their edge labels. If fn returns a non-nil error the walk stops immediately and that error is
// returned; otherwise walkPartialTokens returns nil.
func (root *node) walkPartialTokens(query string, sep byte, fn walkPartialTokensFunc) error {
	queryParts := strings.FieldsFunc(query, func(r rune) bool {
		return r == rune(sep)
	})

	// special case: if the query has no tokens, we match the key "".
	if len(queryParts) == 0 {
		if root.hasValue {
			return fn("", root.value)
		}
		return nil
	}

	// The walk is three mutually recursive steps. In all of them, prefix is the part of the key
	// consumed so far, partialToken is the not-yet-matched remainder of the current query token,
	// and partialTokens are the query tokens that follow it.
	var walkNode func(prefix string, n *node, partialToken string, partialTokens []string) error
	var walkEdge func(prefix string, e *edge, offset int, partialToken string, partialTokens []string) error
	var walkUntilSep func(prefix string, e *edge, offset int, partialTokens []string) error

	// walkNode picks the single outgoing edge of n that can match partialToken, if any.
	// It is always called with len(partialToken) > 0.
	walkNode = func(prefix string, n *node, partialToken string, partialTokens []string) error {
		i := sort.Search(len(n.edgeIndex), func(i int) bool {
			return n.edgeIndex[i] >= partialToken[0]
		})
		if i == len(n.edgeIndex) || n.edgeIndex[i] != partialToken[0] {
			// no edge found
			return nil
		}
		return walkEdge(prefix, n.edges[i], 0, partialToken, partialTokens)
	}

	// walkEdge matches partialToken against the label of e starting at offset.
	walkEdge = func(prefix string, e *edge, offset int, partialToken string, partialTokens []string) error {
		rest := e.label[offset:]
		switch {
		case rest == partialToken:
			// partialToken ends exactly at the end of the edge. The key token may still continue
			// below, so look for the separator starting at the next node.
			child := e.node
			if child.hasValue && len(partialTokens) == 0 {
				if err := fn(prefix+rest, child.value); err != nil {
					return err
				}
			}
			for _, next := range child.edges {
				if err := walkUntilSep(prefix+rest, next, 0, partialTokens); err != nil {
					return err
				}
			}
			return nil
		case strings.HasPrefix(partialToken, rest):
			// partialToken continues past the end of the edge. Keep matching at the next node;
			// the remainder is non-empty because partialToken != rest.
			return walkNode(prefix+rest, e.node, partialToken[len(rest):], partialTokens)
		case strings.HasPrefix(rest, partialToken):
			// partialToken ends inside the edge. Skip the rest of the key token by searching
			// for the separator on this same edge.
			return walkUntilSep(prefix+partialToken, e, offset+len(partialToken), partialTokens)
		default:
			// neither the edge nor partialToken is a prefix of the other. no match.
			return nil
		}
	}

	// walkUntilSep skips the remainder of the current key token, starting at offset within the
	// label of e, and resumes matching with the next query token after the separator.
	walkUntilSep = func(prefix string, e *edge, offset int, partialTokens []string) error {
		suffix := e.label[offset:]
		i := strings.Index(suffix, string(sep))
		switch {
		case i == -1:
			// no separator on this edge: the key token runs to the end of the label.
			if len(partialTokens) == 0 && e.node.hasValue {
				// no more query tokens, so the key ending here is a match
				if err := fn(prefix+suffix, e.node.value); err != nil {
					return err
				}
			}
			for _, next := range e.node.edges {
				if err := walkUntilSep(prefix+suffix, next, 0, partialTokens); err != nil {
					return err
				}
			}
			return nil
		case len(partialTokens) == 0:
			// we found a separator on this edge, but have no more query tokens, so every key
			// below has too many tokens. stop here.
			return nil
		case i == len(suffix)-1:
			// the separator is the last byte of the label: the next key token starts at the
			// child node.
			return walkNode(prefix+suffix, e.node, partialTokens[0], partialTokens[1:])
		default:
			// the next key token starts later on this same edge.
			return walkEdge(prefix+suffix[:i+1], e, offset+i+1, partialTokens[0], partialTokens[1:])
		}
	}

	return walkNode("", root, queryParts[0], queryParts[1:])
}
package main
import (
"reflect"
"testing"
)
// TestStoreAndLoad stores three keys, one of which is a prefix of another,
// and checks that each can be loaded back with the value it was stored with.
func TestStoreAndLoad(t *testing.T) {
	tree := new(node)
	tree.store("foobar", 1)
	tree.store("foo", 2)
	tree.store("bar", 3)

	got, found := tree.load("foobar")
	switch {
	case !found:
		t.Fatal("expected foobar to be found")
	case got != 1:
		t.Fatalf("expected foobar to be 1, got %d", got)
	}

	got, found = tree.load("foo")
	switch {
	case !found:
		t.Fatal("expected foo to be found")
	case got != 2:
		t.Fatalf("expected foo to be 2, got %d", got)
	}

	got, found = tree.load("bar")
	switch {
	case !found:
		t.Fatal("expected bar to be found")
	case got != 3:
		t.Fatalf("expected bar to be 3, got %d", got)
	}
}
// TestLoadNotFound checks that load reports keys as missing when they end
// inside an edge label ("fo", "foob") or share no edge with the tree ("bar").
func TestLoadNotFound(t *testing.T) {
	tree := new(node)
	tree.store("foobar", 1)
	tree.store("foo", 2)

	if _, found := tree.load("fo"); found {
		t.Fatal("expected fo to not be found")
	}
	if _, found := tree.load("foob"); found {
		t.Fatal("expected foob to not be found")
	}
	if _, found := tree.load("bar"); found {
		t.Fatal("expected bar to not be found")
	}
}
// TestStoreAndLoadEmptyKey checks that a value stored under the empty key
// can be loaded back.
func TestStoreAndLoadEmptyKey(t *testing.T) {
	tree := new(node)
	tree.store("", 1)

	got, found := tree.load("")
	switch {
	case !found:
		t.Fatal("expected empty key to be found")
	case got != 1:
		t.Fatalf("expected empty key to be 1, got %d", got)
	}
}
// TestNonExistantEmptyKeyLoad checks that the empty key is reported as
// missing both in an empty tree and in a tree that only holds other keys.
func TestNonExistantEmptyKeyLoad(t *testing.T) {
	tree := new(node)

	if _, found := tree.load(""); found {
		t.Fatal("expected empty key to not be found")
	}

	tree.store("foo", 1)

	if _, found := tree.load(""); found {
		t.Fatal("expected empty key to not be found")
	}
}
// TestOverwrite checks that storing under an existing key replaces the
// previously stored value.
func TestOverwrite(t *testing.T) {
	tree := new(node)

	tree.store("foo", 1)
	got, found := tree.load("foo")
	switch {
	case !found:
		t.Fatal("expected foo to be found")
	case got != 1:
		t.Fatalf("expected foo to be 1, got %d", got)
	}

	tree.store("foo", 2)
	got, found = tree.load("foo")
	switch {
	case !found:
		t.Fatal("expected foo to be found")
	case got != 2:
		t.Fatalf("expected foo to be 2, got %d", got)
	}
}
// TestWalkPartialTokensExactMatch checks that a query whose tokens equal the
// tokens of a stored key matches that key only, and not a longer key that
// has an additional token.
func TestWalkPartialTokensExactMatch(t *testing.T) {
	n := &node{}
	n.store("show version", 1)
	n.store("show version detail", 2)
	n.store("show name", 3)
	n.store("show number", 4)

	var keys []string
	var values []int
	err := n.walkPartialTokens("show version", ' ', func(prefix string, value any) error {
		keys = append(keys, prefix)
		values = append(values, value.(int))
		return nil
	})
	// The callback never fails, so any error here is a bug in the walk itself.
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	expectedKeys := []string{"show version"}
	expectedValues := []int{1}
	if !reflect.DeepEqual(keys, expectedKeys) {
		t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys)
	}
	if !reflect.DeepEqual(values, expectedValues) {
		t.Fatalf("expected values %#v, got %#v", expectedValues, values)
	}
}
// TestWalkPartialTokensPrefixMatch checks that each query token only needs
// to be a prefix of the corresponding key token.
func TestWalkPartialTokensPrefixMatch(t *testing.T) {
	n := &node{}
	n.store("show version", 1)
	n.store("show version detail", 2)
	n.store("show name", 3)
	n.store("show number", 4)

	var keys []string
	var values []int
	err := n.walkPartialTokens("sh ver", ' ', func(prefix string, value any) error {
		keys = append(keys, prefix)
		values = append(values, value.(int))
		return nil
	})
	// The callback never fails, so any error here is a bug in the walk itself.
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	expectedKeys := []string{"show version"}
	expectedValues := []int{1}
	if !reflect.DeepEqual(keys, expectedKeys) {
		t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys)
	}
	if !reflect.DeepEqual(values, expectedValues) {
		t.Fatalf("expected values %#v, got %#v", expectedValues, values)
	}
}
// TestWalkPartialTokensMultipleMatches checks that an ambiguous query
// reports every matching key in order, and that lengthening the last token
// narrows the result down to a single key.
func TestWalkPartialTokensMultipleMatches(t *testing.T) {
	n := &node{}
	n.store("show version", 1)
	n.store("show version detail", 2)
	n.store("show name", 3)
	n.store("show number", 4)

	var keys []string
	var values []int
	collect := func(prefix string, value any) error {
		keys = append(keys, prefix)
		values = append(values, value.(int))
		return nil
	}

	// "n" is a prefix of both "name" and "number".
	if err := n.walkPartialTokens("sh n", ' ', collect); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	expectedKeys := []string{"show name", "show number"}
	expectedValues := []int{3, 4}
	if !reflect.DeepEqual(keys, expectedKeys) {
		t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys)
	}
	if !reflect.DeepEqual(values, expectedValues) {
		t.Fatalf("expected values %#v, got %#v", expectedValues, values)
	}

	keys = nil
	values = nil
	if err := n.walkPartialTokens("sh na", ' ', collect); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	expectedKeys = []string{"show name"}
	expectedValues = []int{3}
	if !reflect.DeepEqual(keys, expectedKeys) {
		t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys)
	}
	if !reflect.DeepEqual(values, expectedValues) {
		t.Fatalf("expected values %#v, got %#v", expectedValues, values)
	}

	keys = nil
	values = nil
	if err := n.walkPartialTokens("sh nu", ' ', collect); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	expectedKeys = []string{"show number"}
	expectedValues = []int{4}
	if !reflect.DeepEqual(keys, expectedKeys) {
		t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys)
	}
	if !reflect.DeepEqual(values, expectedValues) {
		t.Fatalf("expected values %#v, got %#v", expectedValues, values)
	}
}
// TestWalkPartialTokensEdgeDoesntMatch checks that a query token which
// diverges from every stored key ("shaw" vs "show") produces no matches.
func TestWalkPartialTokensEdgeDoesntMatch(t *testing.T) {
	n := &node{}
	n.store("show version", 1)
	n.store("show version detail", 2)
	n.store("show name", 3)
	n.store("show number", 4)

	var keys []string
	var values []int
	err := n.walkPartialTokens("shaw", ' ', func(prefix string, value any) error {
		keys = append(keys, prefix)
		values = append(values, value.(int))
		return nil
	})
	// The callback never fails, so any error here is a bug in the walk itself.
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	expectedKeys := []string(nil)
	expectedValues := []int(nil)
	if !reflect.DeepEqual(keys, expectedKeys) {
		t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys)
	}
	if !reflect.DeepEqual(values, expectedValues) {
		t.Fatalf("expected values %#v, got %#v", expectedValues, values)
	}
}
// TestWalkPartialTokensTooFewTokens checks that a query with fewer tokens
// than every stored key produces no matches.
func TestWalkPartialTokensTooFewTokens(t *testing.T) {
	n := &node{}
	n.store("show version", 1)
	n.store("show version detail", 2)
	n.store("show name", 3)
	n.store("show number", 4)

	var keys []string
	var values []int
	err := n.walkPartialTokens("sh", ' ', func(prefix string, value any) error {
		keys = append(keys, prefix)
		values = append(values, value.(int))
		return nil
	})
	// The callback never fails, so any error here is a bug in the walk itself.
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	expectedKeys := []string(nil)
	expectedValues := []int(nil)
	if !reflect.DeepEqual(keys, expectedKeys) {
		t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys)
	}
	if !reflect.DeepEqual(values, expectedValues) {
		t.Fatalf("expected values %#v, got %#v", expectedValues, values)
	}
}
// TestWalkPartialTokensTooMany checks that a query with more tokens than a
// key that otherwise matches produces no matches.
func TestWalkPartialTokensTooMany(t *testing.T) {
	n := &node{}
	n.store("show version", 1)
	n.store("show version detail", 2)
	n.store("show name", 3)
	n.store("show number", 4)

	var keys []string
	var values []int
	err := n.walkPartialTokens("sh na foo", ' ', func(prefix string, value any) error {
		keys = append(keys, prefix)
		values = append(values, value.(int))
		return nil
	})
	// The callback never fails, so any error here is a bug in the walk itself.
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	expectedKeys := []string(nil)
	expectedValues := []int(nil)
	if !reflect.DeepEqual(keys, expectedKeys) {
		t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys)
	}
	if !reflect.DeepEqual(values, expectedValues) {
		t.Fatalf("expected values %#v, got %#v", expectedValues, values)
	}
}
// TestWalkPartialTokensEmptyQuery checks the special case that an empty
// query matches a value stored under the empty key.
func TestWalkPartialTokensEmptyQuery(t *testing.T) {
	n := &node{}
	n.store("", 1)

	var keys []string
	var values []int
	err := n.walkPartialTokens("", ' ', func(prefix string, value any) error {
		keys = append(keys, prefix)
		values = append(values, value.(int))
		return nil
	})
	// The callback never fails, so any error here is a bug in the walk itself.
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	expectedKeys := []string{""}
	expectedValues := []int{1}
	if !reflect.DeepEqual(keys, expectedKeys) {
		t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys)
	}
	if !reflect.DeepEqual(values, expectedValues) {
		t.Fatalf("expected values %#v, got %#v", expectedValues, values)
	}
}
// TestWalkPartialTokensEmptyQueryNoMatch checks that an empty query does not
// match non-empty keys when nothing is stored under the empty key.
func TestWalkPartialTokensEmptyQueryNoMatch(t *testing.T) {
	n := &node{}
	n.store("show version", 1)

	var keys []string
	var values []int
	err := n.walkPartialTokens("", ' ', func(prefix string, value any) error {
		keys = append(keys, prefix)
		values = append(values, value.(int))
		return nil
	})
	// The callback never fails, so any error here is a bug in the walk itself.
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	expectedKeys := []string(nil)
	expectedValues := []int(nil)
	if !reflect.DeepEqual(keys, expectedKeys) {
		t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys)
	}
	if !reflect.DeepEqual(values, expectedValues) {
		t.Fatalf("expected values %#v, got %#v", expectedValues, values)
	}
}
// TestWalkPartialTokensEmptyTree checks that walking a tree with no keys
// produces no matches.
func TestWalkPartialTokensEmptyTree(t *testing.T) {
	n := &node{}

	var keys []string
	var values []int
	err := n.walkPartialTokens("sh ver", ' ', func(prefix string, value any) error {
		keys = append(keys, prefix)
		values = append(values, value.(int))
		return nil
	})
	// The callback never fails, so any error here is a bug in the walk itself.
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	expectedKeys := []string(nil)
	expectedValues := []int(nil)
	if !reflect.DeepEqual(keys, expectedKeys) {
		t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys)
	}
	if !reflect.DeepEqual(values, expectedValues) {
		t.Fatalf("expected values %#v, got %#v", expectedValues, values)
	}
}
// TestWalkPartialTokensSeparatorWithinEdge checks matching when the
// separator sits in the middle of an edge label: storing "foo bar baz" and
// then "foo bar" leaves a single edge labelled "foo bar" under the root.
func TestWalkPartialTokensSeparatorWithinEdge(t *testing.T) {
	n := &node{}
	n.store("foo bar baz", 1)
	n.store("foo bar", 2)

	var keys []string
	var values []int
	err := n.walkPartialTokens("foo bar", ' ', func(prefix string, value any) error {
		keys = append(keys, prefix)
		values = append(values, value.(int))
		return nil
	})
	// The callback never fails, so any error here is a bug in the walk itself.
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	expectedKeys := []string{"foo bar"}
	expectedValues := []int{2}
	if !reflect.DeepEqual(keys, expectedKeys) {
		t.Fatalf("expected prefixes %#v, got %#v", expectedKeys, keys)
	}
	if !reflect.DeepEqual(values, expectedValues) {
		t.Fatalf("expected values %#v, got %#v", expectedValues, values)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment