Skip to content

Instantly share code, notes, and snippets.

@JSmol

JSmol/trie.py Secret

Created January 13, 2021 21:31
Show Gist options
  • Save JSmol/284484a9b70928eb6098319da62ccfeb to your computer and use it in GitHub Desktop.
Save JSmol/284484a9b70928eb6098319da62ccfeb to your computer and use it in GitHub Desktop.
Implementation of trie data structure
# Trie implementation for algorithms class.
# The concept of a "near miss" factor is given in the written portion.
# Essentially it is a count of mismatches between two "words".
import sys
from time import time
from random import randint, choices
from string import ascii_lowercase as al
# The trie is a flat table of rows; row 0 is the root.  In each row,
# columns 0-25 hold the row index of the child for letters a-z (0 means
# "no child", which is unambiguous because the root is never a child) and
# column 26 holds the word that ends at that node (0 if none).
trie = [[0 for _ in range(27)]]
# get the trie column for a letter
def index(c):
    """Return the position of *c* in ``al`` (a-z -> 0-25), or -1 if absent.

    Callers use the result directly as a column of a trie row, so -1 would
    address column 26 (the word slot) rather than a child column.
    """
    return al.index(c) if c in al else -1
# add key to trie
def add_key(w):
    """Insert the word *w* into the global ``trie``.

    Walks from the root (row 0) one letter at a time, appending a fresh
    27-slot row whenever a child is missing, then stores *w* itself in
    column 26 of the final node so searches can report it.

    Raises:
        ValueError: if *w* contains a character outside a-z.  ``index``
            returns -1 for such characters, and ``trie[i][-1]`` would
            silently address column 26 (the word slot) instead of a child,
            either overwriting a word slot with a row index or making the
            walk continue from a word string (a TypeError later).
    """
    # Validate the whole word before touching the trie so a bad word never
    # leaves a half-built path behind.
    cols = [index(ch) for ch in w]
    if -1 in cols:
        bad = w[cols.index(-1)]
        raise ValueError(f"cannot add {w!r}: {bad!r} is not a letter a-z")
    i = 0
    for c in cols:
        if not trie[i][c]:
            # 0 means "no child": row 0 is the root and is never a child.
            trie[i][c] = len(trie)
            trie.append([0] * 27)
        i = trie[i][c]
    trie[i][26] = w
# find if the word is in the trie
def find(w):
    """Return *w* if it was added to the trie, otherwise None.

    A query that only reaches an interior node (a prefix of a stored word
    that is not itself stored) returns None rather than the raw 0 kept in
    that node's word slot.  A character outside a-z has no child column,
    so it also yields None instead of reading column 26 through index -1.
    """
    i = 0
    for ch in w:
        c = index(ch)
        if c < 0:
            # trie[i][-1] would be the word slot, not a child
            return None
        i = trie[i][c]
        if not i:
            return None
    word = trie[i][26]
    # column 26 holds either the stored word (a str) or 0 when nothing ends here
    return word if isinstance(word, str) else None
# naive prefix search, used to cross-check the trie results
def nieve(s):
    """Add every word of the global list ``l`` that starts with *s* to ``T``.

    A plain linear scan over all loaded words, kept as a reference result
    for the trie-based ``prefix`` search.
    """
    matches = (word for word in l if word.startswith(s))
    for word in matches:
        T.add(word)
# depth-first walk collecting every word below a node
# S is the global set that receives the words found
def search(i):
    """Add to ``S`` every word stored in the subtree rooted at row *i*.

    Children are visited in a-z order before the node's own word is added.
    """
    row = trie[i]
    for child in row[:26]:
        if child:
            search(child)
    if row[26]:
        S.add(row[26])
# find all words with prefix s
def prefix(s):
    """Add to the global set ``S`` every stored word that starts with *s*.

    Walks the trie along *s* and, if the whole prefix is present, collects
    the subtree below it with ``search``.  Always returns None; the results
    are only in ``S``.

    Characters outside a-z have no child column, so the walk stops there.
    This is checked explicitly because ``index`` returns -1 for them, and
    ``trie[i][-1]`` would read the node's word slot (column 26) and could
    then use a word string as a row index.
    """
    i = 0
    for ch in s:
        c = index(ch)
        if c < 0:
            return None
        i = trie[i][c]
        if not i:
            return None
    search(i)
# recursive search for near miss strings
def dfs(s, f, i, j):
    """Add to ``S`` every stored word within *f* misses of *s*.

    Args:
        s: the query string.
        f: remaining miss budget; a branch is abandoned once it is below 0.
        i: depth of row *j*, i.e. how many characters of the path are used.
        j: trie row to expand.

    A miss is a position where the stored word and *s* differ, or a
    position present in only one of them.  A word's cost is therefore its
    mismatches over the shared length plus the length difference, which is
    the same whichever string is the query.
    """
    if f < 0:
        return
    row = trie[j]
    for k in range(26):
        child = row[k]
        if not child:
            continue
        # A matching letter is free; a different letter, or any letter past
        # the end of s, costs one miss.
        cost = 0 if i < len(s) and s[i] == al[k] else 1
        if cost <= f:
            dfs(s, f - cost, i + 1, child)
    w = row[26]
    # Only the shortfall of a shorter word is charged here.  Letters of a
    # longer word beyond len(s) were already paid for on the way down, and
    # charging abs(len(s) - len(w)) again would count them twice.
    if w and len(s) - len(w) <= f:
        S.add(w)
# find stored words that are close to s
# f is the "miss factor": the largest number of misses allowed
def nearmiss(s, f):
    """Collect into the global set ``S`` the stored words within *f* misses of *s*."""
    # begin at the root row (0) with no characters of s consumed yet
    dfs(s, f, i=0, j=0)
with open("words") as f:
    # Keep only words spelled entirely with a-z, after lower-casing them to
    # match the lower-cased queries.  The trie has one child column per
    # letter a-z; for any other character index() returns -1, which
    # addresses column 26 (the word slot) instead of a child and corrupts
    # the trie.  Filtering the shared list `l` also keeps the trie and the
    # naive search in agreement.
    words = (word.lower() for word in f.read().split())
    l = [w for w in words if all(ch in al for ch in w)]
print("constructing trie...")
start = time()
for w in l:
    add_key(w)
print(f"trie built in {time() - start}")
# sys.getsizeof is shallow: on its own it measures only the outer list of
# row references, so the 27-slot rows (the trie's real cost) are added.
# Word strings are shared with `l` and left out of both totals; the int
# entries inside each row are not counted either.
print(f"size of trie in bytes {sys.getsizeof(trie) + sum(map(sys.getsizeof, trie))}")
print(f"size of list in bytes {sys.getsizeof(l)}")
while True:
    print("--- --- --- --- --- --- --- --- --- --- ---")
    try:
        s = input("enter a word\n").strip().lower()
    except EOFError:
        # input exhausted (Ctrl-D or end of a piped file): stop cleanly
        break
    print()
    # Only a-z can be followed through the trie: for any other character
    # index() returns -1, and trie[i][-1] is a node's word slot, not a child.
    if not all(ch in al for ch in s):
        print("please enter letters a-z only")
        continue
    start = time()
    S = set()
    prefix(s)
    print(f"query complete in {time() - start}")
    print()
    start = time()
    T = set()
    nieve(s)
    print(f"nieve query complete in {time() - start}")
    print()
    # both methods should yield equal results
    assert S == T, f"trie and naive prefix searches disagree for {s!r}"
    print(f"found {len(S)} words with prefix {s}")
    print(*S, sep=", ")
    print()
    try:
        f = int(input("enter a near miss factor\n"))
    except ValueError:
        # a typo should not end the session; start over with a new word
        print("the near miss factor must be an integer")
        continue
    except EOFError:
        break
    print()
    start = time()
    S = set()
    nearmiss(s, f)
    print(f"near miss complete in {time() - start}")
    print()
    print(f"found {len(S)} words with miss factor of {f}")
    print(*S, sep=", ")
    print()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment