Skip to content

Instantly share code, notes, and snippets.

@zdebra
Created April 15, 2023 08:36
Show Gist options
  • Save zdebra/74c059fbb7d2f5fbaf72efd20f857b8c to your computer and use it in GitHub Desktop.
Save zdebra/74c059fbb7d2f5fbaf72efd20f857b8c to your computer and use it in GitHub Desktop.
Trie implementation in Go
package main
import "fmt"
// node is one vertex of the trie. It stores the rune on the edge that leads
// to it; the top level of the trie is a plain map[rune]*node (called start
// below) rather than a node of its own.
type node struct {
	children map[rune]*node
	value    rune
	// terminal is true when an inserted word ends at this node. Without it a
	// word that is a prefix of another stored word (e.g. "car" next to
	// "cart") is indistinguishable from an inner path and gets lost.
	terminal bool
}

// newNode returns a non-terminal node holding v with an empty child map.
func newNode(v rune) *node {
	return &node{
		children: map[rune]*node{},
		value:    v,
	}
}

// insertWord adds word to the trie whose top level is start, creating any
// missing nodes on the way and marking the final node as the end of a word.
// start must be non-nil because new first-level nodes are written into it.
// An empty word is ignored.
func insertWord(start map[rune]*node, word []rune) {
	if len(word) == 0 {
		return
	}
	level := start
	var last *node
	for _, r := range word {
		child, ok := level[r]
		if !ok {
			child = newNode(r)
			level[r] = child
		}
		last = child
		level = child.children
	}
	last.terminal = true
}

// find returns every stored word that starts with input.
//
// It returns (nil, false) when no path for input exists in the trie.
// Otherwise it returns the matching words and true; an empty input is a
// prefix of every word, so all stored words are returned. Words come back in
// lexicographic (rune-by-rune) order and never share memory with input.
func find(start map[rune]*node, input []rune) ([][]rune, bool) {
	if len(input) == 0 {
		all := [][]rune{}
		for _, r := range sortedKeys(start) {
			all = append(all, dfs(start[r], nil)...)
		}
		return all, true
	}
	level := start
	var cur *node
	for _, r := range input {
		next, ok := level[r]
		if !ok {
			return nil, false
		}
		cur = next
		level = next.children
	}
	// cur holds input's last rune, so the path above it is input minus that rune.
	return dfs(cur, input[:len(input)-1]), true
}

// dfs returns every complete word in the subtree rooted at n, each spelled
// as prefix followed by the runes from n downwards. prefix is the path above
// n (excluding n.value) and is never modified. Children are visited in
// ascending rune order so results are deterministic.
func dfs(n *node, prefix []rune) [][]rune {
	// Capping prefix's capacity forces append to allocate a fresh slice, so
	// neither the caller's backing array nor a sibling's word is overwritten.
	word := append(prefix[:len(prefix):len(prefix)], n.value)
	var out [][]rune
	if n.terminal {
		out = append(out, word)
	}
	for _, r := range sortedKeys(n.children) {
		out = append(out, dfs(n.children[r], word)...)
	}
	return out
}

// sortedKeys returns the keys of m in ascending order; Go map iteration order
// is unspecified, so ranging over m directly would make output order vary.
func sortedKeys(m map[rune]*node) []rune {
	keys := make([]rune, 0, len(m))
	for r := range m {
		// Insertion sort: quadratic in the node's fan-out, which is fine for
		// alphabet-sized fan-outs and keeps the file's imports unchanged.
		// Switch to sort.Slice if nodes become very wide.
		i := len(keys)
		keys = append(keys, r)
		for i > 0 && keys[i-1] > r {
			keys[i] = keys[i-1]
			i--
		}
		keys[i] = r
	}
	return keys
}
// main builds a small dictionary trie and runs a few prefix searches
// against it, printing each result through find_test.
func main() {
	root := map[rune]*node{}
	for _, word := range []string{"banana", "ball", "car", "color"} {
		insertWord(root, []rune(word))
	}
	// Queries are run in order; "cx" exercises the not-found path.
	for _, prefix := range []string{"ba", "ban", "c", "cx"} {
		find_test(root, prefix)
	}
}
// find_test prints the prefix being searched, then either "not found" or every
// word returned by find, tab-separated on a single line.
//
// NOTE(review): the underscore name is non-idiomatic Go (MixedCaps), but it
// is kept because callers refer to it by this name.
func find_test(start map[rune]*node, input string) {
	fmt.Printf("searching prefix %q\n", input)
	matches, ok := find(start, []rune(input))
	if ok {
		for _, m := range matches {
			fmt.Printf("%s\t", string(m))
		}
		fmt.Println()
		return
	}
	fmt.Println("not found")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment