Skip to content

Instantly share code, notes, and snippets.

@kastnerkyle
Last active December 8, 2022 04:41
Show Gist options
  • Save kastnerkyle/617637bb51c14d928e2f201d838f8498 to your computer and use it in GitHub Desktop.
Save kastnerkyle/617637bb51c14d928e2f201d838f8498 to your computer and use it in GitHub Desktop.
String matching using fft / fht (hartley transform). For learning, not optimal speed
# Author: Kyle Kastner
# BSD 3-Clause
# Thanks to jakevdp for the nice blog post on FFT
# https://jakevdp.github.io/blog/2013/08/28/understanding-the-fft/
# Summary
# http://www.arazim-project.com/sites/default/files/public/lesson_sums/1fft.pdf
# Details on hartley and many xforms
# https://caxapa.ru/thumbs/455725/algorithms.pdf
# pg 332 http://sep.stanford.edu/data/media/public/oldreports/sep38/38_29.pdf
# https://dsp.stackexchange.com/questions/22320/implement-fast-hartley-transform
# https://stackoverflow.com/questions/67940011/what-is-wrong-in-my-code-using-dht-to-compute-convolution
# https://www.embedded.com/doing-hartley-smartly/
# http://www.theradixpoint.com/hartley/hartley.html
# http://www.cs.cmu.edu/afs/cs/academic/class/15750-s16/Handouts/WildCards2006.pdf
import numpy as np
# Deliberately left empty: populating this dict could change the behaviour
# of any other code that tests membership in it. It is kept only so that
# existing references to the name keep resolving.
_dft_cache = {}
# Forward DFT matrices keyed by transform length N.
_dft_matrix_cache = {}


def dft_slow(x):
    """Compute the discrete Fourier transform of ``x`` with an O(N^2) matrix product.

    The N x N matrix ``exp(-2j * pi * k * n / N)`` is built once per length
    and memoised, so repeated calls with the same length (as happens at the
    leaves of a recursive FFT) only pay for the matrix-vector product.

    Parameters
    ----------
    x : array_like
        Input samples; the transform runs along the first axis.

    Returns
    -------
    numpy.ndarray
        Complex spectrum with the same leading length as ``x``.
    """
    x = np.asarray(x)
    N = x.shape[0]
    M = _dft_matrix_cache.get(N)
    if M is None:
        n = np.arange(N)
        k = n.reshape((N, 1))
        M = np.exp(-2j * np.pi * k * n / N)
        # Store the matrix: the original never did, so its cache was dead.
        _dft_matrix_cache[N] = M
    return np.dot(M, x)
# Inverse DFT matrices keyed by transform length N.
_idft_cache = {}


def idft_slow(x):
    """Compute the inverse discrete Fourier transform with an O(N^2) matrix product.

    The N x N matrix ``exp(+2j * pi * k * n / N)`` is built once per length
    and memoised in ``_idft_cache``.

    Parameters
    ----------
    x : array_like
        Complex spectrum; the transform runs along the first axis.

    Returns
    -------
    numpy.ndarray
        Complex samples, scaled by ``1 / N``.
    """
    x = np.asarray(x)
    N = x.shape[0]
    M = _idft_cache.get(N)
    if M is None:
        n = np.arange(N)
        k = n.reshape((N, 1))
        M = np.exp(2j * np.pi * k * n / N)
        # Store the matrix: the original never did, so its cache was dead.
        _idft_cache[N] = M
    r = np.dot(M, x)
    return 1. / N * r
def fft_basic(x):
    """Recursive decimation-in-time FFT that falls back to ``dft_slow``.

    The input is split into even- and odd-indexed samples, each half is
    transformed recursively, and the halves are recombined with twiddle
    factors. Sub-problems of length <= 32 are handed to ``dft_slow``.

    Note that, despite the error message, only oddness of the current
    length is checked at each level of the recursion.

    Raises
    ------
    ValueError
        If the length at any recursion level is odd.
    """
    x = np.asarray(x)
    N = x.shape[0]
    if N % 2 != 0:
        raise ValueError("size of x must be a power of 2")
    if N <= 32:  # this cutoff should be optimized
        return dft_slow(x)

    half = int(N // 2)
    evens = fft_basic(x[::2])
    odds = fft_basic(x[1::2])
    # All N twiddles are computed; each half of the output uses its own slice.
    twiddle = np.exp(-2j * np.pi * np.arange(N) / N)
    lower = evens + twiddle[:half] * odds
    upper = evens + twiddle[half:] * odds
    return np.concatenate([lower, upper])
def ifft_basic(x):
    """Inverse of ``fft_basic`` via the conjugation identity.

    Computes ``conj(fft_basic(conj(x))) / N`` where ``N`` is the leading
    length of ``x``.
    """
    conj_in = np.conjugate(np.asarray(x))
    spectrum = np.conjugate(fft_basic(conj_in))
    return spectrum / conj_in.shape[0]
def fft_vectorized(x):
    """Non-recursive FFT that processes every sub-problem of a level at once.

    A direct DFT is applied simultaneously to all length-``min(N, 32)``
    sub-problems, then the butterfly stages are applied level by level
    until the full length-N transform is assembled.

    Raises
    ------
    ValueError
        If the length of ``x`` is not a power of 2.
    """
    x = np.asarray(x)
    N = x.shape[0]
    if np.log2(N) % 1 > 0:
        raise ValueError("size of x must be a power of 2")

    # Base-case size; plays the role of the recursion cutoff and should
    # itself be a power of 2.
    base = min(N, 32)

    # O[N^2] DFT on all length-`base` sub-problems at once.
    idx = np.arange(base)
    col = idx[:, None]
    dft_matrix = np.exp(-2j * np.pi * idx * col / base)
    stage = np.dot(dft_matrix, x.reshape((base, -1)))

    # Each pass doubles the transform length and halves the column count.
    while stage.shape[0] < N:
        rows = stage.shape[0]
        split = int(stage.shape[1] // 2)
        left = stage[:, :split]
        right = stage[:, split:]
        twiddle = np.exp(-1j * np.pi * np.arange(rows) / rows)[:, None]
        stage = np.vstack([left + twiddle * right, left - twiddle * right])
    return stage.ravel()
def ifft_vectorized(x):
    """Inverse of ``fft_vectorized`` via the conjugation identity.

    Computes ``conj(fft_vectorized(conj(x))) / N`` where ``N`` is the
    leading length of ``x``.
    """
    conj_in = np.conjugate(np.asarray(x))
    spectrum = np.conjugate(fft_vectorized(conj_in))
    return spectrum / conj_in.shape[0]
def dht_by_npfft(x):
    """Discrete Hartley transform derived from NumPy's FFT.

    Returns the real part minus the imaginary part of ``np.fft.fft(x)``.
    """
    spectrum = np.fft.fft(x)
    return np.real(spectrum) - np.imag(spectrum)
def idht_by_npfft(X):
    """Inverse Hartley transform: the forward transform scaled by ``1 / len(X)``."""
    scale = 1. / len(X)
    return scale * dht_by_npfft(X)
def cas(x):
    """Return ``sin(x) + cos(x)``, the kernel used by the Hartley transforms here."""
    return np.sin(x) + np.cos(x)


# Hartley transform matrices keyed by transform length N.
_dht_cache = {}


def dht_slow(x):
    """Compute the discrete Hartley transform with an O(N^2) matrix product.

    The N x N matrix ``cas(2 * pi * k * n / N)`` is built once per length
    and memoised in ``_dht_cache``.

    Parameters
    ----------
    x : array_like
        Input samples; the transform runs along the first axis.

    Returns
    -------
    numpy.ndarray
        Hartley spectrum with the same leading length as ``x``.
    """
    x = np.asarray(x)
    N = x.shape[0]
    # Look up (and fill) this function's own cache. The original tested
    # membership in the DFT cache and never stored anything.
    M = _dht_cache.get(N)
    if M is None:
        n = np.arange(N)
        k = n.reshape((N, 1))
        M = cas(2 * np.pi * k * n / N)
        _dht_cache[N] = M
    return np.dot(M, x)
# Not referenced by idht_slow itself; kept so the module-level name still exists.
_idht_cache = {}


def idht_slow(X):
    """Inverse Hartley transform: ``dht_slow`` applied to ``X`` and scaled by ``1 / len(X)``."""
    scale = 1. / len(X)
    return scale * dht_slow(X)
def fht_basic(x):
    """Recursive fast Hartley transform that falls back to ``dht_slow``.

    The input is split into even- and odd-indexed samples, each half is
    transformed recursively, and the halves are recombined with cosine and
    sine factors. Unlike the FFT butterfly, the sine term uses the odd
    half's transform in index-mirrored order (k -> (N/2 - k) mod N/2).

    Note that, despite the error message, only oddness of the current
    length is checked at each level of the recursion.

    Raises
    ------
    ValueError
        If the length at any recursion level is odd.
    """
    x = np.asarray(x)
    N = x.shape[0]
    if N % 2 != 0:
        raise ValueError("size of x must be a power of 2")
    if N <= 32:  # this cutoff should be optimized
        return dht_slow(x)

    evens = fht_basic(x[::2])
    odds = fht_basic(x[1::2])
    # Element 0 stays in place, the remaining elements are reversed.
    # References:
    # https://caxapa.ru/thumbs/455725/algorithms.pdf
    # pg 332 http://sep.stanford.edu/data/media/public/oldreports/sep38/38_29.pdf
    odds_mirrored = np.roll(odds[::-1], 1)
    # Only N // 2 factors are needed, rather than N with slicing as in the FFT.
    theta = 2 * np.pi * np.arange(N // 2) / N
    cos_part = np.cos(theta) * odds
    sin_part = np.sin(theta) * odds_mirrored
    return np.concatenate([evens + cos_part + sin_part,
                           evens - cos_part - sin_part])
def ifht_basic(X):
    """Inverse fast Hartley transform: ``fht_basic`` applied to ``X`` and scaled by ``1 / len(X)``."""
    X = np.asarray(X)
    scale = 1. / len(X)
    return scale * fht_basic(X)
# Hartley transform matrices (for the binary variant) keyed by length N.
_dht_binary_cache = {}


def dht_slow_binary(x):
    """Discrete Hartley transform specialised for binary (0/1 or bool) input.

    Instead of a full matrix-vector product, the columns of the Hartley
    matrix at the non-zero positions of ``x`` are summed. For binary data
    this is equivalent to the ordinary transform; it may not be faster
    unless the input is very sparse.

    The N x N matrix ``cas(2 * pi * k * n / N)`` is built once per length
    and memoised in ``_dht_binary_cache``.

    Parameters
    ----------
    x : array_like
        One-dimensional binary input.

    Returns
    -------
    numpy.ndarray
        Hartley spectrum of length ``len(x)``.
    """
    x = np.asarray(x)
    N = x.shape[0]
    # Look up (and fill) this function's own cache. The original tested
    # membership in the DFT cache and never stored anything.
    M = _dht_binary_cache.get(N)
    if M is None:
        n = np.arange(N)
        k = n.reshape((N, 1))
        M = cas(2 * np.pi * k * n / N)
        _dht_binary_cache[N] = M
    return M[:, np.where(x)[0]].sum(axis=1)
def fht_basic_binary(x):
    """Recursive fast Hartley transform whose base case is ``dht_slow_binary``.

    Same even/odd recombination as the general fast Hartley transform, but
    sub-problems of length <= 32 are handed to ``dht_slow_binary``.

    Note that, despite the error message, only oddness of the current
    length is checked at each level of the recursion.

    Raises
    ------
    ValueError
        If the length at any recursion level is odd.
    """
    x = np.asarray(x)
    N = x.shape[0]
    if N % 2 != 0:
        raise ValueError("size of x must be a power of 2")
    if N <= 32:  # this cutoff should be optimized
        return dht_slow_binary(x)

    evens = fht_basic_binary(x[::2])
    odds = fht_basic_binary(x[1::2])
    # Element 0 stays in place, the remaining elements are reversed.
    # References:
    # https://caxapa.ru/thumbs/455725/algorithms.pdf
    # pg 332 http://sep.stanford.edu/data/media/public/oldreports/sep38/38_29.pdf
    odds_mirrored = np.roll(odds[::-1], 1)
    # Only N // 2 factors are needed, rather than N with slicing as in the FFT.
    theta = 2 * np.pi * np.arange(N // 2) / N
    cos_part = np.cos(theta) * odds
    sin_part = np.sin(theta) * odds_mirrored
    return np.concatenate([evens + cos_part + sin_part,
                           evens - cos_part - sin_part])
def fht_conv_real_binary(x, y):
    """Circular convolution of real ``x`` with binary ``y`` via Hartley transforms.

    ``x`` goes through ``fht_basic`` and ``y`` through ``fht_basic_binary``.
    The spectra are combined as ``0.5 * (X * (Y + Y') + X' * (Y - Y'))``,
    where the primed spectra are index-mirrored (element 0 fixed, the rest
    reversed), and the result is transformed back with ``ifht_basic``.
    """
    spec_x = fht_basic(x)
    spec_y = fht_basic_binary(y)
    # Mirror: reverse, then rotate by one so that element 0 stays at index 0.
    spec_x_mirror = np.roll(np.flip(spec_x), shift=1)
    spec_y_mirror = np.roll(np.flip(spec_y), shift=1)
    y_sum = spec_y + spec_y_mirror
    y_diff = spec_y - spec_y_mirror
    combined = 0.5 * (spec_x * y_sum + spec_x_mirror * y_diff)
    return ifht_basic(combined)
def fht_conv_real_real(x, y):
    """Circular convolution of two real sequences via Hartley transforms.

    Both inputs go through ``fht_basic``. The spectra are combined as
    ``0.5 * (X * (Y + Y') + X' * (Y - Y'))``, where the primed spectra are
    index-mirrored (element 0 fixed, the rest reversed), and the result is
    transformed back with ``ifht_basic``.
    """
    spec_x = fht_basic(x)
    spec_y = fht_basic(y)
    # Mirror: reverse, then rotate by one so that element 0 stays at index 0.
    spec_x_mirror = np.roll(np.flip(spec_x), shift=1)
    spec_y_mirror = np.roll(np.flip(spec_y), shift=1)
    y_sum = spec_y + spec_y_mirror
    y_diff = spec_y - spec_y_mirror
    combined = 0.5 * (spec_x * y_sum + spec_x_mirror * y_diff)
    return ifht_basic(combined)
def fft_conv_real_real(x, y):
    """Circular convolution of two real sequences using NumPy's FFT.

    Multiplies the two spectra, inverts the product and returns its real part.
    """
    product = np.fft.fft(x) * np.fft.fft(y)
    return np.fft.ifft(product).real
def binary_string_search(s, p, alg="fht"):
    """Locate the best alignment of boolean pattern ``p`` inside boolean text ``s``.

    The text is convolved with the reversed pattern, and the complement of
    the text with the reversed complement of the pattern. The sum of the
    two results is the per-alignment agreement score; when its maximum is
    not ``len(p)`` there was no exact match.

    Both inputs are zero-padded internally to a common length that is at
    least ``len(s) + len(p) - 1``, at least 32, and a multiple of 32 (a
    requirement of the fast Hartley implementation).

    Parameters
    ----------
    s : numpy.ndarray of bool
        Text to search in.
    p : numpy.ndarray of bool
        Pattern to search for.
    alg : {"fht", "fft"}, optional
        Convolution backend; defaults to the Hartley-based one.

    Returns
    -------
    match_index : integer
        Start index in ``s`` of the highest-scoring alignment. When several
        alignments tie, which one wins is decided by ``np.argmax`` on the
        floating-point scores.
    match_sums : numpy.ndarray
        Score for every alignment, indexed by the position in the padded
        text where the pattern's last element lands.

    Raises
    ------
    ValueError
        If ``alg`` is not a known backend.
    """
    assert s.dtype == bool
    assert p.dtype == bool

    if alg == "fht":
        conv = fht_conv_real_real
    elif alg == "fft":
        conv = fft_conv_real_real
    else:
        raise ValueError("Conv alg unknown")

    # Pad to at least N + M - 1, at least 32, rounded up to a multiple of 32.
    t_len = max(len(p) + len(s) - 1, 32)
    t_len += -t_len % 32

    def _pad(bits):
        # Zero-extend a boolean vector to the common convolution length.
        padded = np.zeros(t_len, dtype=bool)
        padded[:len(bits)] = bits
        return padded

    # The complements are taken before padding so the padding stays zero
    # in both the plain and the complemented vectors.
    s_pad = _pad(s)
    s_c_pad = _pad(~s)
    # Reversing the pattern turns the convolution into a correlation, see
    # https://cs6505.wordpress.com/dc-iii-pattern-matching/
    # https://codeforces.com/topic/59717/en3
    p_r = _pad(p[::-1])
    p_r_c = _pad((~p)[::-1])

    rr = conv(s_pad.astype("float32"), p_r.astype("float32"))
    rr_c = conv(s_c_pad.astype("float32"), p_r_c.astype("float32"))
    match_sums = rr_c + rr
    # The convolution index marks where the pattern ENDS; shift back to its start.
    match_index = np.argmax(match_sums) - (len(p) - 1)
    return match_index, match_sums
def full_string_search(s, p, vocab="auto"):
    """Locate pattern ``p`` in text ``s`` using one binary search per symbol.

    For every distinct symbol in ``s + p`` both strings are turned into
    boolean indicator vectors and passed to ``binary_string_search``. The
    per-symbol score vectors are stacked, summed over symbols, and the
    arg-max of that sum (shifted by ``len(p) - 1``) is returned.

    Parameters
    ----------
    s, p : sequences supporting ``+`` (e.g. ``str``)
        Text and pattern.
    vocab : str, optional
        Only ``"auto"`` is supported.

    Returns
    -------
    best_match_idx : integer
        Arg-max of the summed scores minus ``len(p) - 1``.
    best_match_val : numpy.ndarray
        Stacked score vectors, one row per symbol.

    Raises
    ------
    ValueError
        If ``vocab`` is anything other than ``"auto"``.
    """
    if vocab != "auto":
        raise ValueError("Need to manually pass vocab size")
    symbols = sorted(set(s + p))

    per_symbol_sums = []
    for sym in symbols:
        # Indicator vectors: True where the character equals this symbol.
        s_mask = np.array([ch == sym for ch in s]).astype("bool")
        p_mask = np.array([ch == sym for ch in p]).astype("bool")
        # Padding is handled inside the binary search.
        _, sums = binary_string_search(s_mask, p_mask)
        per_symbol_sums.append(sums)

    best_match_val = np.array(per_symbol_sums)
    totals = best_match_val.sum(axis=0)
    best_match_idx = totals.argmax() - (len(p) - 1)
    return best_match_idx, best_match_val
def full_string_wildcard_search(s, p, vocab="auto", alg="fht"):
    """Find the first occurrence of ``p`` in ``s``, where ``"*"`` matches any symbol.

    Uses the convolution-based algorithm from
    http://www.cs.cmu.edu/afs/cs/academic/class/15750-s16/Handouts/WildCards2006.pdf
    Each symbol is mapped to a positive integer and ``"*"`` to 0. For every
    alignment the quantity ``sum(p**3 * t) - 2 * sum(p**2 * t**2) +
    sum(p * t**3)`` is evaluated with three convolutions, and an alignment
    matches when it is zero. This also works for patterns without wildcards.

    Only alignments where the pattern lies completely inside the text are
    considered. The zero padding added for the transform is encoded like
    the wildcard, so partially overlapping alignments would otherwise be
    reported as matches.

    Parameters
    ----------
    s, p : sequences supporting ``+`` (e.g. ``str``)
        Text and pattern.
    vocab : str, optional
        Only ``"auto"`` is supported.
    alg : {"fht", "fft"}, optional
        Convolution backend; defaults to the Hartley-based one.

    Returns
    -------
    best_match_idx : integer
        Start index of the first full match, or -1 when there is none.
    best_match_val : numpy.ndarray
        Rounded mismatch score for every position of the padded
        convolution; entries in the padding region are not meaningful.

    Raises
    ------
    ValueError
        If ``vocab`` is not ``"auto"`` or ``alg`` is not a known backend.
    """
    if vocab == "auto":
        a_v = sorted(set(s + p))
    else:
        raise ValueError("Need to manually pass vocab size")

    if alg == "fht":
        conv = fht_conv_real_real
    elif alg == "fft":
        conv = fft_conv_real_real
    else:
        raise ValueError("Conv alg unknown")

    # Wildcard -> 0; every other symbol -> 1, 2, ... in sorted order.
    lu = {"*": 0}
    for sym in a_v:
        if sym != "*":
            lu[sym] = len(lu)

    n_s = len(s)
    n_p = len(p)
    s_codes = [lu[s_i] for s_i in s]
    p_codes = [lu[p_i] for p_i in p]

    # Pad to at least N + M - 1, at least 32, rounded up to a multiple of 32
    # (required by the fast Hartley implementation).
    t_len = max(n_p + n_s - 1, 32)
    t_len += -t_len % 32

    s_pad = np.zeros(t_len, dtype="float32")
    s_pad[:n_s] = s_codes
    # Reversing the pattern turns the convolution into a correlation, see
    # https://cs6505.wordpress.com/dc-iii-pattern-matching/
    # https://codeforces.com/topic/59717/en3
    p_pad = np.zeros(t_len, dtype="float32")
    p_pad[:n_p] = p_codes[::-1]

    # p^3 * t
    term_1 = conv(p_pad ** 3, s_pad)
    # p^2 * t^2
    term_2 = conv(p_pad ** 2, s_pad ** 2)
    # p * t^3
    term_3 = conv(p_pad, s_pad ** 3)
    A = np.round(term_1 - 2 * term_2 + term_3)

    # A[j] scores the alignment whose last pattern symbol sits on text
    # position j. Only j in [n_p - 1, n_s - 1] is a full overlap; the
    # offset inside that slice is the start index of the match.
    first_full = max(n_p - 1, 0)
    hits = np.where(A[first_full:n_s] == 0)[0]
    best_match_idx = hits[0] if len(hits) > 0 else -1
    best_match_val = A
    return best_match_idx, best_match_val
# Binary examples: check that the reported match index follows the pattern
# as it moves through the text.
s = np.array([1, 1, 0, 1, 0, 1, 0, 0], dtype=bool)
p = np.array([1, 1, 0], dtype=bool)
res1, res1_sum = binary_string_search(s, p)
assert res1 == 0

s = np.array([1, 0, 0, 1, 0, 1, 1, 0], dtype=bool)
p = np.array([1, 1, 0], dtype=bool)
res2, res2_sum = binary_string_search(s, p)
assert res2 == 5

s = np.array([1, 0, 1, 1, 0, 1, 0, 1], dtype=bool)
p = np.array([1, 1, 0], dtype=bool)
res3, res3_sum = binary_string_search(s, p)
assert res3 == 2

# When the best score is not len(p), no substring matched.
s = np.array([1, 0, 1, 0, 0, 1, 0, 1], dtype=bool)
p = np.array([1, 1, 0], dtype=bool)
res4, res4_sum = binary_string_search(s, p)
# Round to guard against floating-point noise.
assert np.round(np.max(res4_sum)) != 3
# Double check that the reported position is not a real match.
assert np.any(s[res4:res4+len(p)] != p)

# There can be multiple matches; that is not handled here.
# For a full match the summed score is vocabulary size * len(query).
s = "hello my darling, hello my baby, hello my ragtime gal"
p = "rag"
res5, res5_sum = full_string_search(s, p)
assert res5 == 42

s = "hello my darling, hello my baby, hello my ragtime gal"
p = "dar"
res6, res6_sum = full_string_search(s, p)
assert res6 == 9

# Wildcard query: "*" stands for any single character.
s = "hello my darling, hello my baby, hello my ragtime gal"
p = "d*r"
res7, res7_sum = full_string_wildcard_search(s, p)
assert res7 == 9
# Boolean-input checks for the dht and fht variants.
t2 = np.array([1, 0, 0, 1, 1, 1, 0, 1, 0, 1] * 8, dtype=bool)
r2 = dht_slow_binary(t2)
r3 = dht_slow(t2.astype("float32"))
t2_rec1 = idht_slow(r2)
print("binary dht matches standard dht:", np.allclose(r2, r3))
print("dht -> inverse dht matches input:", np.allclose(t2_rec1, t2))

r4 = fht_basic_binary(t2)
t2_rec2 = ifht_basic(r4)
print("fht -> inverse fht matches input:", np.allclose(t2_rec2, t2))
# Each entry is (label, forward transform, inverse transform); applying the
# inverse to the forward result should reproduce the input.
paired_transforms = [
    ("direct dft", dft_slow, idft_slow),
    ("direct fft", fft_basic, ifft_basic),
    ("direct vectorized fft", fft_vectorized, ifft_vectorized),
    ("dht, npfft based", dht_by_npfft, idht_by_npfft),
    ("dht, slow", dht_slow, idht_slow),
    ("dht, basic", fht_basic, ifht_basic),
]
t = np.array([1, 7, 3, 4, 2, 5, 1, 9] * 8, dtype="float32")
for name, fwd, bwd in paired_transforms:
    r_t = bwd(fwd(t))
    passed = np.allclose(r_t, t)
    if passed:
        print("{} ok: ".format(name), passed)
    else:
        # On failure, dump the reconstruction for inspection.
        print(r_t)
# Each entry is (label, reference implementation, implementation under test);
# both should produce the same Hartley spectrum for the signal t.
match_transforms = [
    ("slow dht <-> dht with npfft", dht_slow, dht_by_npfft),
    ("slow dht <-> fht basic", dht_slow, fht_basic),
]
for name, alg1, alg2 in match_transforms:
    match_baseline_dht = np.allclose(alg1(t), alg2(t))
    if match_baseline_dht:
        print("{} ok: ".format(name), match_baseline_dht)
    else:
        # On mismatch, show both outputs and their absolute difference.
        a = alg1(t)
        b = alg2(t)
        print(a)
        print(b)
        print(np.abs(a - b))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment