Last active
December 8, 2022 04:41
-
-
Save kastnerkyle/617637bb51c14d928e2f201d838f8498 to your computer and use it in GitHub Desktop.
String matching using the FFT / FHT (fast Hartley transform). Written for learning, not optimized for speed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Author: Kyle Kastner
# BSD 3-Clause
# Thanks to jakevdp for the nice blog post on the FFT:
# https://jakevdp.github.io/blog/2013/08/28/understanding-the-fft/
# Summary:
# http://www.arazim-project.com/sites/default/files/public/lesson_sums/1fft.pdf
# Details on the Hartley transform and many other transforms:
# https://caxapa.ru/thumbs/455725/algorithms.pdf
# p. 332 of http://sep.stanford.edu/data/media/public/oldreports/sep38/38_29.pdf
# https://dsp.stackexchange.com/questions/22320/implement-fast-hartley-transform
# https://stackoverflow.com/questions/67940011/what-is-wrong-in-my-code-using-dht-to-compute-convolution
# https://www.embedded.com/doing-hartley-smartly/
# http://www.theradixpoint.com/hartley/hartley.html
# Wildcard matching with convolutions:
# http://www.cs.cmu.edu/afs/cs/academic/class/15750-s16/Handouts/WildCards2006.pdf
import numpy as np
_dft_cache = {}


def dft_slow(x):
    """Direct O(N^2) discrete Fourier transform of a 1-D sequence.

    X[k] = sum_n x[n] * exp(-2j*pi*k*n/N), evaluated as a single
    matrix-vector product with the N x N DFT matrix.

    NOTE(review): nothing in the visible source stores into _dft_cache, so
    the cached branch is never taken and the matrix is rebuilt on every
    call. Before populating it, make sure no other function tests
    ``N in _dft_cache`` and then reads a *different* cache dict.
    """
    x = np.asarray(x)
    size = x.shape[0]
    if size in _dft_cache:
        basis = _dft_cache[size]
    else:
        idx = np.arange(size)
        # row index k as a column vector times column index n as a row vector
        basis = np.exp(-2j * np.pi * idx.reshape((size, 1)) * idx / size)
    return basis.dot(x)
_idft_cache = {}


def idft_slow(x):
    """Direct O(N^2) inverse discrete Fourier transform of a 1-D sequence.

    x[n] = (1/N) * sum_k X[k] * exp(2j*pi*k*n/N).

    The N x N inverse-DFT matrix is memoized in ``_idft_cache`` keyed by N.
    (Previously the matrix was looked up there but never stored, so every
    call rebuilt it.)

    Parameters
    ----------
    x : array_like
        Spectrum to invert (complex data in).

    Returns
    -------
    np.ndarray
        Complex time-domain sequence of the same length.

    Raises
    ------
    ZeroDivisionError
        For empty input, from the 1/N scale.
    """
    x = np.asarray(x)
    N = x.shape[0]
    M = _idft_cache.get(N)
    if M is None:
        n = np.arange(N)
        k = n.reshape((N, 1))
        M = np.exp(2j * np.pi * k * n / N)
        _idft_cache[N] = M
    r = np.dot(M, x)
    return 1. / N * r
def fft_basic(x):
    """Recursive radix-2 decimation-in-time FFT.

    The input is split into even- and odd-indexed halves until the length is
    <= 32, where dft_slow is applied directly. The halves are recombined with
    the twiddle factors exp(-2j*pi*k/N).

    Raises
    ------
    ValueError
        If the length at any recursion level is odd. The check runs before
        the <= 32 cutoff, so short odd inputs are rejected too. Despite the
        message, lengths such as 48 or 96 are accepted.
    """
    x = np.asarray(x)
    N = x.shape[0]
    if N % 2:
        raise ValueError("size of x must be a power of 2")
    if N <= 32:  # this cutoff should be optimized
        return dft_slow(x)
    evens = fft_basic(x[0::2])
    odds = fft_basic(x[1::2])
    half = N // 2
    twiddle = np.exp(-2j * np.pi * np.arange(N) / N)
    return np.concatenate([evens + twiddle[:half] * odds,
                           evens + twiddle[half:] * odds])
def ifft_basic(x):
    """Inverse of fft_basic using the conjugation identity.

    ifft(x) = conj(fft(conj(x))) / N, so the forward recursive FFT is reused.
    """
    conj_in = np.conjugate(np.asarray(x))
    spectrum = np.conjugate(fft_basic(conj_in))
    return spectrum / conj_in.shape[0]
def fft_vectorized(x):
    """Non-recursive, vectorized radix-2 FFT.

    All length-``min(N, 32)`` sub-problems are solved at once with one
    O(block^2) DFT matrix product. The butterfly stages are then applied
    level by level, doubling the number of rows on each pass until N is
    reached.

    Raises
    ------
    ValueError
        If the length of ``x`` is not a power of two.
    """
    x = np.asarray(x)
    N = x.shape[0]
    if np.log2(N) % 1 > 0:
        raise ValueError("size of x must be a power of 2")
    # base block size; also a power of 2 because N is
    block = min(N, 32)
    idx = np.arange(block)
    dft_mat = np.exp(-2j * np.pi * idx * idx[:, None] / block)
    # column c of the reshaped input is the stride-(N/block) subsequence x[c::N/block]
    X = dft_mat.dot(x.reshape((block, -1)))
    while X.shape[0] < N:
        rows, cols = X.shape
        half = cols // 2
        twiddle = np.exp(-1j * np.pi * np.arange(rows) / rows)[:, None]
        top = X[:, :half]
        scaled_odd = twiddle * X[:, half:]
        X = np.vstack([top + scaled_odd, top - scaled_odd])
    return X.ravel()
def ifft_vectorized(x):
    """Inverse of fft_vectorized via conj(fft(conj(x))) / N."""
    conj_in = np.conjugate(np.asarray(x))
    return np.conjugate(fft_vectorized(conj_in)) / conj_in.shape[0]
def dht_by_npfft(x):
    """Discrete Hartley transform computed from numpy's FFT.

    fft(x)[k] = sum_n x[n] * (cos - i*sin)(2*pi*k*n/N). For real x, the DHT
    sum_n x[n] * cas(2*pi*k*n/N) therefore equals Re(F) - Im(F).
    """
    F = np.fft.fft(x)
    return F.real - F.imag
def idht_by_npfft(X):
    """Inverse DHT via dht_by_npfft: the DHT is its own inverse up to 1/N."""
    # transform first, then scale (keeps the original evaluation order)
    H = dht_by_npfft(X)
    return (1. / len(X)) * H
def cas(x):
    """Hartley kernel, elementwise: cas(x) = cos(x) + sin(x)."""
    return np.cos(x) + np.sin(x)
_dht_cache = {}


def dht_slow(x):
    """Direct O(N^2) discrete Hartley transform of a 1-D sequence.

    H[k] = sum_n x[n] * cas(2*pi*k*n/N), with cas(t) = cos(t) + sin(t),
    evaluated as a matrix-vector product with the N x N cas matrix.

    The cas matrix is memoized in ``_dht_cache`` keyed by N. Previously the
    membership test used ``_dft_cache`` while the read used ``_dht_cache``,
    so a populated DFT cache would have raised KeyError. The matrix was
    also never stored.
    """
    x = np.asarray(x)
    N = x.shape[0]
    M = _dht_cache.get(N)
    if M is None:
        n = np.arange(N)
        k = n.reshape((N, 1))
        M = cas(2 * np.pi * k * n / N)
        _dht_cache[N] = M
    return np.dot(M, x)
# NOTE(review): not referenced by idht_slow or by any code visible here
_idht_cache = {}


def idht_slow(X):
    """Inverse DHT by the direct method: the DHT is its own inverse up to 1/N."""
    # transform first, then scale (keeps the original evaluation order)
    H = dht_slow(X)
    return (1. / len(X)) * H
def fht_basic(x):
    """Recursive radix-2 fast Hartley transform.

    With E, O the transforms of the even- and odd-indexed halves and
    theta_k = 2*pi*k/N for k = 0 .. N/2 - 1:

        H[k]       = E[k] + cos(theta_k) O[k] + sin(theta_k) O[(N/2 - k) mod N/2]
        H[k + N/2] = E[k] - cos(theta_k) O[k] - sin(theta_k) O[(N/2 - k) mod N/2]

    Lengths <= 32 use dht_slow directly. See
    https://caxapa.ru/thumbs/455725/algorithms.pdf and
    pg 332 http://sep.stanford.edu/data/media/public/oldreports/sep38/38_29.pdf

    Raises
    ------
    ValueError
        If the length at any recursion level is odd.
    """
    x = np.asarray(x)
    N = x.shape[0]
    if N % 2:
        raise ValueError("size of x must be a power of 2")
    if N <= 32:  # this cutoff should be optimized
        return dht_slow(x)
    H_even = fht_basic(x[0::2])
    H_odd = fht_basic(x[1::2])
    half = N // 2
    # (half - k) mod half: element 0 stays put, the rest are reversed
    mirror = (-np.arange(half)) % half
    # N // 2 length factors, rather than N with indexing as in the FFT
    theta = 2 * np.pi * np.arange(half) / N
    c_term = np.cos(theta) * H_odd
    s_term = np.sin(theta) * H_odd[mirror]
    return np.concatenate([H_even + c_term + s_term,
                           H_even - c_term - s_term])
def ifht_basic(X):
    """Inverse fast Hartley transform: fht_basic followed by a 1/N scale."""
    X = np.asarray(X)
    H = fht_basic(X)
    return (1. / len(X)) * H
_dht_binary_cache = {}


def dht_slow_binary(x):
    """Direct discrete Hartley transform specialised for 0/1 input.

    For a 0/1 vector, the cas-matrix product reduces to summing the matrix
    columns at the nonzero positions of ``x``. Nonzero values are not used
    as weights, so this matches dht_slow only for binary data.

    The cas matrix is memoized in ``_dht_binary_cache`` keyed by N.
    Previously the membership test used ``_dft_cache`` while the read used
    ``_dht_binary_cache`` (latent KeyError), and the matrix was never stored.
    """
    x = np.asarray(x)
    N = x.shape[0]
    M = _dht_binary_cache.get(N)
    if M is None:
        n = np.arange(N)
        k = n.reshape((N, 1))
        M = cas(2 * np.pi * k * n / N)
        _dht_binary_cache[N] = M
    # equivalent to dht specifically for binary data
    # may not be faster unless very sparse
    return M[:, np.where(x)[0]].sum(axis=1)
def fht_basic_binary(x):
    """Recursive radix-2 fast Hartley transform for 0/1 input.

    Same butterfly as fht_basic, but lengths <= 32 go to dht_slow_binary.
    That leaf ignores the magnitudes of nonzero entries, so results are only
    meaningful for binary data.

    Raises
    ------
    ValueError
        If the length at any recursion level is odd.
    """
    x = np.asarray(x)
    N = x.shape[0]
    if N % 2:
        raise ValueError("size of x must be a power of 2")
    if N <= 32:  # this cutoff should be optimized
        return dht_slow_binary(x)
    H_even = fht_basic_binary(x[0::2])
    H_odd = fht_basic_binary(x[1::2])
    half = N // 2
    # (half - k) mod half: element 0 stays put, the rest are reversed
    mirror = (-np.arange(half)) % half
    # N // 2 length factors, rather than N with indexing as in the FFT
    theta = 2 * np.pi * np.arange(half) / N
    c_term = np.cos(theta) * H_odd
    s_term = np.sin(theta) * H_odd[mirror]
    return np.concatenate([H_even + c_term + s_term,
                           H_even - c_term - s_term])
def fht_conv_real_binary(x, y):
    """Circular convolution of real ``x`` with binary ``y`` via the DHT.

    Identical to fht_conv_real_real except that ``y`` is transformed with
    fht_basic_binary. With X(-k) denoting the index-reversed transform
    (reverse, then roll by one so index 0 stays at 0):

        Z = 0.5 * (X * (Y + Y(-k)) + X(-k) * (Y - Y(-k)))

    Presumably ``x`` and ``y`` share one length accepted by fht_basic.
    """
    X = fht_basic(x)
    Y = fht_basic_binary(y)
    X_neg = np.roll(X[::-1], 1)
    Y_neg = np.roll(Y[::-1], 1)
    Z = 0.5 * (X * (Y + Y_neg) + X_neg * (Y - Y_neg))
    return ifht_basic(Z)
def fht_conv_real_real(x, y):
    """Circular convolution of two real sequences via the Hartley transform.

    With X, Y the DHTs and X(-k), Y(-k) their index-reversed versions
    (reverse, then roll by one so index 0 stays at 0), the DHT convolution
    theorem gives

        Z = 0.5 * (X * (Y + Y(-k)) + X(-k) * (Y - Y(-k)))

    and the result is ifht(Z). Presumably ``x`` and ``y`` share one length
    accepted by fht_basic.
    """
    X = fht_basic(x)
    Y = fht_basic(y)
    X_neg = np.roll(X[::-1], 1)
    Y_neg = np.roll(Y[::-1], 1)
    Z = 0.5 * (X * (Y + Y_neg) + X_neg * (Y - Y_neg))
    return ifht_basic(Z)
def fft_conv_real_real(x, y):
    """Circular convolution of two real sequences via numpy's FFT.

    The product of the spectra is inverted and only the real part is kept.
    """
    spectrum = np.fft.fft(x) * np.fft.fft(y)
    return np.real(np.fft.ifft(spectrum))
def binary_string_search(s, p, alg="fht"):
    """Score every alignment of boolean pattern ``p`` against boolean text ``s``.

    For each alignment, the score counts the pattern positions that agree
    with the text:

    * ``conv(s, reversed(p))`` counts positions where both are True;
    * ``conv(~s, reversed(~p))`` counts positions where both are False.

    A full match therefore scores ``len(p)``. Padding is done internally.

    Parameters
    ----------
    s, p : np.ndarray of bool
        Text and pattern.
    alg : {"fht", "fft"}
        Convolution backend. "fht" (the default, and the previously
        hard-coded choice) uses fht_conv_real_real; "fft" uses
        fft_conv_real_real.

    Returns
    -------
    match_index : int
        ``argmax(match_sums) - (len(p) - 1)``, i.e. the start index in ``s`` of
        the best alignment. This is returned even when there is no full match
        (then ``max(match_sums) < len(p)``). It can be negative when the best
        alignment is a partial overlap at the front of ``s``.
    match_sums : np.ndarray
        Per-alignment scores; conv index ``i`` places the last pattern
        element on ``s[i]``.

    Raises
    ------
    AssertionError
        If ``s`` or ``p`` is not a bool array.
    ValueError
        If ``alg`` is not "fht" or "fft".
    """
    assert s.dtype == bool
    assert p.dtype == bool
    if alg == "fht":
        conv = fht_conv_real_real
    elif alg == "fft":
        conv = fft_conv_real_real
    else:
        raise ValueError("Conv alg unknown")

    # A linear (non-wrapping) correlation needs len(s) + len(p) - 1 samples.
    # fht_basic halves its input until the length is <= 32 and rejects any odd
    # length on the way, so a plain multiple of 32 is not enough
    # (544 = 32 * 17 eventually reaches 17). Round up to a power of two >= 32.
    need = len(p) + len(s) - 1
    t_len = max(32, 1 << max(need - 1, 0).bit_length())

    def _pad(a):
        # False-pad up to t_len; complements are taken BEFORE padding so the
        # padded tail never counts as an agreeing False position
        return np.concatenate((a, np.zeros(t_len - len(a), dtype=bool)))

    s_pad = _pad(s)
    s_c_pad = _pad(~s)
    # reversing the pattern turns convolution into correlation
    # https://cs6505.wordpress.com/dc-iii-pattern-matching/
    # https://codeforces.com/topic/59717/en3
    p_r = _pad(p[::-1])
    p_r_c = _pad(~p[::-1])

    rr = conv(s_pad.astype("float32"), p_r.astype("float32"))
    rr_c = conv(s_c_pad.astype("float32"), p_r_c.astype("float32"))
    match_sums = rr_c + rr
    # convolution index counts from the first overlap, so shift back to a start index
    match_index = np.argmax(match_sums) - (len(p) - 1)
    return match_index, match_sums
def full_string_search(s, p, vocab="auto"):
    """Best alignment of pattern ``p`` in text ``s`` over an arbitrary alphabet.

    One binary_string_search is run per vocabulary symbol, on the boolean
    "is this symbol" indicators of ``s`` and ``p``. The per-alignment scores
    are then summed over symbols, so a full match scores
    ``len(p) * len(vocabulary)``. Padding is done inside
    binary_string_search.

    Parameters
    ----------
    s, p : str (or sequences supporting ``+`` and ``set``)
        Text and pattern.
    vocab : str
        Only "auto" is supported: the vocabulary is every symbol in s + p.

    Returns
    -------
    best_match_idx : int
        Start index in ``s`` of the highest-scoring alignment. There is no
        "not found" value; compare ``best_match_val.sum(axis=0).max()`` with
        ``len(p) * best_match_val.shape[0]`` to confirm a full match.
    best_match_val : np.ndarray
        Per-symbol score rows, shape (n_symbols, padded_length).

    Raises
    ------
    ValueError
        If ``vocab`` is not "auto".
    """
    if vocab == "auto":
        a_v = sorted(set(s + p))
    else:
        raise ValueError("Need to manually pass vocab size")
    all_res_sum = []
    for v_i in a_v:
        # binary indicator strings for this vocabulary symbol
        s_a = np.array([c == v_i for c in s], dtype=bool)
        p_a = np.array([c == v_i for c in p], dtype=bool)
        _, res_sum = binary_string_search(s_a, p_a)
        all_res_sum.append(res_sum)
    best_match_val = np.array(all_res_sum)
    best_match_idx = best_match_val.sum(axis=0).argmax() - (len(p) - 1)
    return best_match_idx, best_match_val
def full_string_wildcard_search(s, p, vocab="auto", alg="fht"):
    """First occurrence of ``p`` in ``s`` where "*" matches any symbol.

    Wildcard matching from
    http://www.cs.cmu.edu/afs/cs/academic/class/15750-s16/Handouts/WildCards2006.pdf
    Symbols map to distinct positive integers and "*" maps to 0. The score

        A = sum_j p_j t (p_j - t)^2 = conv(p^3, t) - 2 conv(p^2, t^2) + conv(p, t^3)

    is zero exactly when every aligned pair is equal or contains a "*". A "*"
    in the text therefore also acts as a wildcard.

    Parameters
    ----------
    s, p : str (or sequences supporting ``+`` and ``set``)
        Text and pattern.
    vocab : str
        Only "auto" is supported: the vocabulary is every symbol in s + p.
    alg : {"fht", "fft"}
        Convolution backend; "fht" is the default (previously hard-coded).

    Returns
    -------
    best_match_idx : int
        Start index in ``s`` of the first alignment whose whole pattern lies
        inside ``s`` and scores 0, or -1 if there is none.
    best_match_val : np.ndarray
        Rounded scores per conv index ``i`` (the last pattern symbol on
        ``s[i]``). Zeros outside ``len(p) - 1 <= i <= len(s) - 1`` are
        artifacts of overlapping the zero (= wildcard) padding.

    Raises
    ------
    ValueError
        If ``vocab`` is not "auto" or ``alg`` is unknown.
    """
    if vocab == "auto":
        a_v = sorted(set(s + p))
    else:
        raise ValueError("Need to manually pass vocab size")
    if alg == "fht":
        conv = fht_conv_real_real
    elif alg == "fft":
        conv = fft_conv_real_real
    else:
        raise ValueError("Conv alg unknown")

    # "*" -> 0 so it drops out of every product; other symbols -> 1, 2, ...
    lu = {"*": 0}
    for code, sym in enumerate([c for c in a_v if c != "*"], start=1):
        lu[sym] = code

    n_s, n_p = len(s), len(p)
    # A linear convolution needs n_s + n_p - 1 samples. fht_basic needs a
    # length that stays even while halving down to <= 32, so round up to a
    # power of two >= 32 (a plain multiple of 32, e.g. 544, can fail).
    need = n_s + n_p - 1
    t_len = max(32, 1 << max(need - 1, 0).bit_length())
    # float64 keeps cubed symbol codes exact even for large alphabets
    t = np.zeros(t_len)
    t[:n_s] = [lu[c] for c in s]
    # reversing the pattern turns convolution into correlation
    # https://cs6505.wordpress.com/dc-iii-pattern-matching/
    # https://codeforces.com/topic/59717/en3
    q = np.zeros(t_len)
    q[:n_p] = [lu[c] for c in p][::-1]

    term_1 = conv(q ** 3, t)        # p^3 * t
    term_2 = conv(q ** 2, t ** 2)   # p^2 * t^2
    term_3 = conv(q, t ** 3)        # p * t^3
    A = np.round(term_1 - 2 * term_2 + term_3)

    # Only n_p - 1 <= i <= n_s - 1 keeps the whole pattern inside s. Zeros at
    # other indices come from the padding and must not be reported; previously
    # they produced negative indices or matches hanging off the end of s.
    first_valid = n_p - 1
    hits = np.flatnonzero(A == 0)
    hits = hits[(hits >= first_valid) & (hits <= n_s - 1)]
    best_match_idx = int(hits[0]) - first_valid if hits.size else -1
    return best_match_idx, A
# Example of the binary case: check that the reported match moves with the
# position of the pattern in the text.
s = np.array([1, 1, 0, 1, 0, 1, 0, 0]).astype("bool")
p = np.array([1, 1, 0]).astype("bool")
res1, res1_sum = binary_string_search(s, p)
assert res1 == 0

s = np.array([1, 0, 0, 1, 0, 1, 1, 0]).astype("bool")
p = np.array([1, 1, 0]).astype("bool")
res2, res2_sum = binary_string_search(s, p)
assert res2 == 5

s = np.array([1, 0, 1, 1, 0, 1, 0, 1]).astype("bool")
p = np.array([1, 1, 0]).astype("bool")
res3, res3_sum = binary_string_search(s, p)
assert res3 == 2

# If the best score is not == len(p), there was no matching substring.
s = np.array([1, 0, 1, 0, 0, 1, 0, 1]).astype("bool")
p = np.array([1, 1, 0]).astype("bool")
res4, res4_sum = binary_string_search(s, p)
# round in case of fp issues
assert np.round(np.max(res4_sum)) != 3
# Double-check the failed match by brute force. res4 is only the best partial
# alignment and may be negative (-1 here in exact arithmetic). Slicing s with
# it can then give an empty window, and comparing that elementwise with p
# raises on newer NumPy releases.
assert not any(np.array_equal(s[i:i + len(p)], p)
               for i in range(len(s) - len(p) + 1))

# There can be multiple matches; only the best one is reported here.
# A full match scores vocabulary size * len(query).
s = "hello my darling, hello my baby, hello my ragtime gal"
p = "rag"
res5, res5_sum = full_string_search(s, p)
assert res5 == 42

s = "hello my darling, hello my baby, hello my ragtime gal"
p = "dar"
res6, res6_sum = full_string_search(s, p)
assert res6 == 9

s = "hello my darling, hello my baby, hello my ragtime gal"
p = "d*r"
res7, res7_sum = full_string_wildcard_search(s, p)
assert res7 == 9

# bool test for dht and fht
t2 = np.array([1, 0, 0, 1, 1, 1, 0, 1, 0, 1] * 8).astype("bool")
r2 = dht_slow_binary(t2)
r3 = dht_slow(t2.astype("float32"))
t2_rec1 = idht_slow(r2)
print("binary dht matches standard dht:", np.allclose(r2, r3))
print("dht -> inverse dht matches input:", np.allclose(t2_rec1, t2))
r4 = fht_basic_binary(t2)
t2_rec2 = ifht_basic(r4)
print("fht -> inverse fht matches input:", np.allclose(t2_rec2, t2))

# forward -> inverse round trips
paired_transforms = [
    ("direct dft", dft_slow, idft_slow),
    ("direct fft", fft_basic, ifft_basic),
    ("direct vectorized fft", fft_vectorized, ifft_vectorized),
    ("dht, npfft based", dht_by_npfft, idht_by_npfft),
    ("dht, slow", dht_slow, idht_slow),
    ("dht, basic", fht_basic, ifht_basic),
]
t = np.array([1, 7, 3, 4, 2, 5, 1, 9] * 8).astype("float32")
for name, fwd, bwd in paired_transforms:
    r_t = bwd(fwd(t))
    passed = np.allclose(r_t, t)
    if passed:
        print("{} ok: ".format(name), passed)
    else:
        print(r_t)

# different implementations of the same forward transform must agree
match_transforms = [
    ("slow dht <-> dht with npfft", dht_slow, dht_by_npfft),
    ("slow dht <-> fht basic", dht_slow, fht_basic),
]
for name, alg1, alg2 in match_transforms:
    a = alg1(t)
    b = alg2(t)
    match_baseline_dht = np.allclose(a, b)
    if match_baseline_dht:
        print("{} ok: ".format(name), match_baseline_dht)
    else:
        print(a)
        print(b)
        print(np.abs(a - b))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment