Skip to content

Instantly share code, notes, and snippets.

@terapyon
Forked from Arachnid/automata.py
Created July 22, 2012 04:08
Show Gist options
  • Save terapyon/3158352 to your computer and use it in GitHub Desktop.
Save terapyon/3158352 to your computer and use it in GitHub Desktop.
import bisect
class NFA(object):
EPSILON = object()
ANY = object()
def __init__(self, start_state):
self.transitions = {}
self.final_states = set()
self._start_state = start_state
@property
def start_state(self):
return frozenset(self._expand(set([self._start_state])))
def add_transition(self, src, input, dest):
self.transitions.setdefault(src, {}).setdefault(input, set()).add(dest)
def add_final_state(self, state):
self.final_states.add(state)
def is_final(self, states):
return self.final_states.intersection(states)
def _expand(self, states):
frontier = set(states)
while frontier:
state = frontier.pop()
new_states = self.transitions.get(state, {}).get(NFA.EPSILON, set()).difference(states)
frontier.update(new_states)
states.update(new_states)
return states
def next_state(self, states, input):
dest_states = set()
for state in states:
state_transitions = self.transitions.get(state, {})
dest_states.update(state_transitions.get(input, []))
dest_states.update(state_transitions.get(NFA.ANY, []))
return frozenset(self._expand(dest_states))
def get_inputs(self, states):
inputs = set()
for state in states:
inputs.update(self.transitions.get(state, {}).keys())
return inputs
def to_dfa(self):
dfa = DFA(self.start_state)
frontier = [self.start_state]
seen = set()
while frontier:
current = frontier.pop()
inputs = self.get_inputs(current)
for input in inputs:
if input == NFA.EPSILON: continue
new_state = self.next_state(current, input)
if new_state not in seen:
frontier.append(new_state)
seen.add(new_state)
if self.is_final(new_state):
dfa.add_final_state(new_state)
if input == NFA.ANY:
dfa.set_default_transition(current, new_state)
else:
dfa.add_transition(current, input, new_state)
return dfa
class DFA(object):
def __init__(self, start_state):
self.start_state = start_state
self.transitions = {}
self.defaults = {}
self.final_states = set()
def add_transition(self, src, input, dest):
self.transitions.setdefault(src, {})[input] = dest
def set_default_transition(self, src, dest):
self.defaults[src] = dest
def add_final_state(self, state):
self.final_states.add(state)
def is_final(self, state):
return state in self.final_states
def next_state(self, src, input):
state_transitions = self.transitions.get(src, {})
return state_transitions.get(input, self.defaults.get(src, None))
def next_valid_string(self, input):
state = self.start_state
stack = []
# Evaluate the DFA as far as possible
for i, x in enumerate(input):
stack.append((input[:i], state, x))
state = self.next_state(state, x)
if not state: break
else:
stack.append((input[:i+1], state, None))
if self.is_final(state):
# Input word is already valid
return input
# Perform a 'wall following' search for the lexicographically smallest
# accepting state.
while stack:
path, state, x = stack.pop()
x = self.find_next_edge(state, x)
if x:
path += x
state = self.next_state(state, x)
if self.is_final(state):
return path
stack.append((path, state, None))
return None
def find_next_edge(self, s, x):
if x is None:
x = u'\0'
else:
x = unichr(ord(x) + 1)
state_transitions = self.transitions.get(s, {})
if x in state_transitions or s in self.defaults:
return x
labels = sorted(state_transitions.keys())
pos = bisect.bisect_left(labels, x)
if pos < len(labels):
return labels[pos]
return None
def levenshtein_automata(term, k):
nfa = NFA((0, 0))
for i, c in enumerate(term):
for e in range(k + 1):
# Correct character
nfa.add_transition((i, e), c, (i + 1, e))
if e < k:
# Deletion
nfa.add_transition((i, e), NFA.ANY, (i, e + 1))
# Insertion
nfa.add_transition((i, e), NFA.EPSILON, (i + 1, e + 1))
# Substitution
nfa.add_transition((i, e), NFA.ANY, (i + 1, e + 1))
for e in range(k + 1):
if e < k:
nfa.add_transition((len(term), e), NFA.ANY, (len(term), e + 1))
nfa.add_final_state((len(term), e))
return nfa
def find_all_matches(word, k, lookup_func):
"""Uses lookup_func to find all words within levenshtein distance k of word.
Args:
word: The word to look up
k: Maximum edit distance
lookup_func: A single argument function that returns the first word in the
database that is greater than or equal to the input argument.
Yields:
Every matching word within levenshtein distance k from the database.
"""
lev = levenshtein_automata(word, k).to_dfa()
match = lev.next_valid_string(u'\0')
while match:
next = lookup_func(match)
if not next:
return
if match == next:
yield match
next = next + u'\0'
match = lev.next_valid_string(next)
# coding: utf-8
import automata
from automata_test import Matcher, get_words
def get_ja_words():
words = get_words(filename='ja_words.txt')
return words
if __name__ == '__main__':
import sys
args = sys.argv
if len(args) > 2:
s = args[1]
k = int(args[2])
elif len(args) > 1:
s = args[1]
k = 1
else:
s = 'food' # for test
k = 1
if not isinstance(s, unicode):
s = s.decode('utf-8')
words = get_ja_words()
m = Matcher(words)
li = list(automata.find_all_matches(s, k, m))
print len(li), ', '.join(li)
print m.probes
import bisect
import random
import automata
class Matcher(object):
def __init__(self, l):
self.l = l
self.probes = 0
def __call__(self, w):
self.probes += 1
pos = bisect.bisect_left(self.l, w)
if pos < len(self.l):
return self.l[pos]
else:
return None
def levenshtein(s1, s2):
if len(s1) < len(s2):
return levenshtein(s2, s1)
if not s1:
return len(s2)
previous_row = xrange(len(s2) + 1)
for i, c1 in enumerate(s1):
current_row = [i + 1]
for j, c2 in enumerate(s2):
insertions = previous_row[j + 1] + 1 # j+1 instead of j since previous_row and current_row are one character longer
deletions = current_row[j] + 1 # than s2
substitutions = previous_row[j] + (c1 != c2)
current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row
return previous_row[-1]
class BKNode(object):
def __init__(self, term):
self.term = term
self.children = {}
def insert(self, other):
distance = levenshtein(self.term, other)
if distance in self.children:
self.children[distance].insert(other)
else:
self.children[distance] = BKNode(other)
def search(self, term, k, results=None):
if results is None:
results = []
distance = levenshtein(self.term, term)
counter = 1
if distance <= k:
results.append(self.term)
for i in range(max(0, distance - k), distance + k + 1):
child = self.children.get(i)
if child:
counter += child.search(term, k, results)
return counter
def get_words(filename="/usr/share/dict/web2"):
words = (x.strip().lower().decode('utf-8') for x in open(filename, 'r'))
sorted_words = sorted(words)
# words10 = [x for x in sorted_words if random.random() <= 0.1]
# words100 = [x for x in sorted_words if random.random() <= 0.01]
return sorted_words
if __name__ == '__main__':
import sys
args = sys.argv
if len(args) > 2:
s = args[1]
k = int(args[2])
elif len(args) > 1:
s = args[1]
k = 1
else:
s = 'food' # for test
k = 1
words = get_words()
m = Matcher(words)
#assert len(list(automata.find_all_matches('food', 1, m))) == 18
li = list(automata.find_all_matches(s, k, m))
print len(li), li
print m.probes
#m = Matcher(words)
#assert len(list(automata.find_all_matches('food', 2, m))) == 283
#print m.probes
アセンション島
アンドラ
アラブ首長国連邦
アフガニスタン
アンティグァ=バーブーダ
アンギーラ島
アルバニア
アルメニア
オランダ領アンティル
アンゴラ
アジア太平洋連合
南極
アルゼンチン
アメリカンサモア
オーストリア
オーストラリア
アルバ
アゼルバイジャン
ボスニア=ヘルツェゴビナ
バルバドス
バングラデシュ
ベルギー
ブルキナファソ
ブルガリア
バハレーン
ブルンジ
ベナン
バーミューダ諸島
ブルネイ
ボリビア
ブラジル
バハマ
ブータン
ブーベ島
ボツワナ
ベラルーシ
ベリーズ
カナダ
ココス諸島
中央アフリカ共和国
コンゴ
スイス
コートジボアール(象牙海岸)
クック諸島
チリ
カメルーン
中国
コロンビア
コスタリカ
旧チェコスロバキア
キューバ
カーボベルデ
クリスマス諸島
キプロス
チェコ共和国
ドイツ
ジブチ
デンマーク
ドミニカ連邦
ドミニカ共和国
アルジェリア
エクアドル
エストニア
エジプト
西サハラ
エリトリア
スペイン
エチオピア
欧州連合
フィンランド
フィジー
フォークランド諸島
ミクロネシア連邦
フェロー諸島
フランス
フランス(首都圏?)
ガボン
イギリス
グレナダ
グルジア
フランス領ギアナ
ガーナ
ジブラルタル
グリーンランド
ガンビア
ギニア
グアドループ
赤道ギニア
ギリシャ
サウスジョージア島・サウスサンドイッチ島
グアテマラ
グアム
ギニアビサオ
ガイアナ
香港
ハード・マクドナルド諸島
ホンジュラス
クロアチア
ハイチ
ハンガリー
インドネシア
アイルランド
イスラエル
インド
イギリスインド洋領
イラク
イラン
アイスランド
イタリア
ジャマイカ
ヨルダン
日本
ケニア
キルギスタン
カンボジア
キリバス
コモロ
セントクリストファー・ネイビス
北朝鮮
韓国
クウェート
ケイマン諸島
カザフスタン
ラオス
レバノン
セントルシア
リヒテンシュタイン
スリランカ
リベリア
レソト
リトアニア
ルクセンブルク
ラトビア
リビア
モロッコ
モナコ
モルドバ
モンテネグロ
マダガスカル
マーシャル諸島
マケドニア
マリ
ミャンマー
モンゴル
マカオ
北マリアナ諸島
マルチニーク
モーリタニア
モントセラト
マルタ
モーリシャス
モルディブ
マラウイ
メキシコ
マレーシア
モザンビーク
ナミビア
ニューカレドニア
ニジェール
ノーフォーク諸島
ナイジェリア
ニカラグア
オランダ
ノルウェー
ネパール
ナウル
中立地帯
ニウエ
ニュージーランド
オマーン
パナマ
ペルー
フランス領ポリネシア
パプアニューギニア
フィリピン
パキスタン
ポーランド
サンピエール・ミクロン諸島
ピトケアン諸島
プエルトリコ
パレスチナ
ポルトガル
パラオ
パラグアイ
カタール
レユニオン島
ルーマニア
セルビア共和国
ロシア連邦
ルワンダ
サウジアラビア
ソロモン諸島
セイシェル
スーダン
スウェーデン
シンガポール
セントヘレナ島
スロベニア
スバールバル・ヤンマイエン島
スロバキア
シエラレオネ
サンマリノ
セネガル
ソマリア
スリナム
サントメ・プリンシペ
旧ソビエト連邦
エルサルバドル
シリア
スワジランド
タークス・ナイコス諸島
チャド
フランス南方領土
トーゴ
タイ
タジキスタン
トケラウ
トルクメニスタン
チュニジア
トンガ
東ティモール(インドネシア領)
トルコ
トリニダード・トバゴ
ツバル
台湾
タンザニア
ウクライナ
ウガンダ
イギリス
合衆国周辺離島
アメリカ
ウルグアイ
ウズベキスタン
バチカン市国
セントビンセントおよびグレナディン諸島
ベネズエラ
バージン諸島
バージン諸島
ベトナム
バヌアツ
ワリスフツナ諸島
西サモア
イエメン
マヨット島
ユーゴスラビア
南アフリカ
ザンビア
ザイール
ジンバブエ
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment