Skip to content

Instantly share code, notes, and snippets.

@viswanathgs
Created January 15, 2020 18:52
Show Gist options
  • Save viswanathgs/2ec3e1e82fab9c90748bae0bafa5e439 to your computer and use it in GitHub Desktop.
Save viswanathgs/2ec3e1e82fab9c90748bae0bafa5e439 to your computer and use it in GitHub Desktop.
class Trie:
    """A character trie whose edges carry probabilities.

    Each node stores the character on the edge leading to it and that
    edge's probability.  The character '$' marks the end of a word and may
    not have children.  The root built by ``make_trie`` has character ''
    and probability 1.0.
    """

    def __init__(self, character, prob):
        self.character = character
        self.probability = prob
        # Maps each child's character to the child node.
        self.children = {}

    def add_child(self, child):
        """Attach ``child``, or reuse an existing child with the same character.

        Returns the node stored under ``child.character`` afterwards.  When a
        node with that character already exists, it is returned instead of
        ``child`` (this is how shared prefixes are merged).

        Raises:
            AssertionError: if this node is the end token '$', or if a child
                with the same character already exists with a different
                probability.  The checks raise explicitly instead of using
                ``assert`` so they are not stripped under ``python -O``.
        """
        if self.character == '$':
            raise AssertionError("Cannot add child to end token")
        existing = self.children.setdefault(child.character, child)
        if existing.probability != child.probability:
            raise AssertionError(
                "Child %r already exists with a different probability"
                % (child.character,))
        return existing

    def dfs(self, prefix='', prefix_prob=1.0):
        """Enumerate every complete word reachable from this node.

        Returns a list of ``(word, probability)`` pairs.  ``word`` is
        ``prefix`` followed by the characters on the path from this node down
        to a '$' node (both inclusive).  ``probability`` is ``prefix_prob``
        times the product of the node probabilities on that path.  Paths that
        never reach '$' contribute nothing.
        """
        new_prefix = prefix + self.character
        new_prob = prefix_prob * self.probability
        if self.character == '$':
            return [(new_prefix, new_prob)]
        results = []
        for child in self.children.values():
            results.extend(child.dfs(new_prefix, new_prob))
        return results

    @staticmethod
    def make_trie(raw_list):
        """Build a trie from words given as ``(character, probability)`` pairs.

        Each item of ``raw_list`` is one word.  Words sharing a prefix share
        nodes, and the shared characters must agree on probability (see
        ``add_child``).  Returns the root node ('' with probability 1.0).
        """
        root = Trie('', 1.0)
        for item in raw_list:
            node = root
            for character, prob in item:
                node = node.add_child(Trie(character, prob))
        return root
# Incorrect impl, kept for comparison with BeamSearch2.  When a hypothesis
# reaches the end token '$' it is retired and the beam width k is reduced, so
# finished hypotheses never compete with later, more probable ones.  On the
# demo data in __main__ with k=2 this returns 'cas$' although 'cat$' is best.
class BeamSearch:
    """Beam search over a Trie whose beam shrinks as hypotheses finish.

    A decoding is a dict with keys 'output' (string built so far), 'node'
    (trie node reached) and 'probability' (product of edge probabilities).
    """

    # Init takes Trie's root
    def __init__(self, trie_root):
        # Kept for reference; k_step_decoding receives its root as an argument.
        self.trie_root = trie_root

    def max_k(self, decodings, k):
        """Return the k decodings with maximum probability, best first."""
        return sorted(decodings, key=lambda x: x['probability'], reverse=True)[:k]

    def select_k_recur(self, decodings, k):
        """Expand ``decodings`` level by level until the beam is used up.

        Each step expands every decoding by all children of its node and keeps
        the top ``k`` candidates.  Candidates whose output ends in '$' are
        appended to ``self.final_decodings`` and each one lowers ``k`` by one.
        The search stops when ``k`` reaches 0 or no unfinished decoding is left.
        The original recursed forever in the second case.

        Requires ``self.final_decodings`` to exist (k_step_decoding sets it).
        Implemented as a loop so long words cannot hit the recursion limit.
        """
        while k > 0 and decodings:
            # Generate every one-character extension of the current beam.
            candidates = [
                {
                    'output': decoding['output'] + child.character,
                    'node': child,
                    'probability': decoding['probability'] * child.probability,
                }
                for decoding in decodings
                for child in decoding['node'].children.values()
            ]
            unfinished = []
            for decoding in self.max_k(candidates, k):
                if decoding['output'].endswith('$'):
                    # Retire the finished beam and shrink the beam width.
                    self.final_decodings.append(decoding)
                    k -= 1
                else:
                    unfinished.append(decoding)
            decodings = unfinished

    def k_step_decoding(self, trie_root, k):
        """Return the output of the most probable finished decoding.

        Raises IndexError if no decoding ever reached the end token '$'.
        """
        self.final_decodings = []
        # Start from the root itself rather than its children so that a '$'
        # directly under the root (an empty word) is retired like any other
        # finished hypothesis instead of being expanded and lost.
        root_decoding = {
            'output': trie_root.character,
            'node': trie_root,
            'probability': 1.0,
        }
        self.select_k_recur([root_decoding], k)
        # Select the output with maximum probability among all finished beams.
        return self.max_k(self.final_decodings, 1)[0]['output']
class BeamSearch2:
    """Beam search over a Trie in which finished hypotheses stay in the beam.

    A decoding is a dict with keys 'output' (string built so far), 'node'
    (trie node reached) and 'probability' (product of edge probabilities).
    Decodings that reached the end token '$' keep competing with growing
    ones for the k slots, so the beam width never shrinks.
    """

    def __init__(self, trie_root):
        # Remembered for reference; k_step_decoding receives its root directly.
        self.trie_root = trie_root

    def max_k(self, decodings, k):
        """Return up to k decodings ordered from most to least probable.

        The sort is stable, so tied decodings keep their incoming order.
        """
        ranked = list(decodings)
        ranked.sort(key=lambda entry: entry['probability'], reverse=True)
        return ranked[:k]

    def select_k_recur(self, decodings, k):
        """Advance the beam one level at a time until nothing can be extended.

        At every step the candidate pool consists of the finished ('$')
        decodings already in the beam plus every one-character extension of
        every beam entry.  The best ``k`` of that pool form the next beam, and
        unfinished dead ends (nodes without children) drop out.  Returns the
        beam produced by the first step in which no entry had any children.
        """
        beam = decodings
        while True:
            # Finished hypotheses carry over so they keep competing for a slot.
            pool = [entry for entry in beam if entry['node'].character == '$']
            grew = False
            for parent in beam:
                for child in parent['node'].children.values():
                    pool.append({
                        'output': parent['output'] + child.character,
                        'node': child,
                        'probability': parent['probability'] * child.probability,
                    })
                    grew = True
            beam = self.max_k(pool, k)
            if not grew:
                return beam

    def k_step_decoding(self, trie_root, k):
        """Return the output of the most probable decoding left in the beam.

        The search starts at ``trie_root`` with an empty output and
        probability 1.0.  Raises IndexError if the search ends with an empty
        beam.
        """
        start = {'output': '', 'node': trie_root, 'probability': 1.0}
        survivors = self.select_k_recur([start], k)
        best = self.max_k(survivors, 1)
        return best[0]['output']
if __name__ == '__main__':
    # Three words sharing the prefix 'c': "c", "cat" and "cas".  Each pair is
    # (character, edge probability); '$' terminates a word.
    data = [
        [('c', 1.0), ('$', 0.1)],
        [('c', 1.0), ('a', 1.0), ('t', 0.7), ('$', 1.0)],
        [('c', 1.0), ('a', 1.0), ('s', 0.8), ('$', 0.2)],
    ]
    root = Trie.make_trie(data)
    k = 2
    print("Brute force:")
    # Rank every complete word by probability.  A one-argument lambda is used
    # because tuple-parameter unpacking (``lambda (c, p): p``) is
    # Python-2-only syntax and a SyntaxError on Python 3 (PEP 3113).
    print(sorted(root.dfs(), key=lambda pair: pair[1], reverse=True))
    # With k=2 the flawed BeamSearch settles on 'cas$' while BeamSearch2
    # finds the true best word 'cat$'.
    print("BeamSearch1:")
    print(BeamSearch(root).k_step_decoding(root, k))
    print("BeamSearch2:")
    print(BeamSearch2(root).k_step_decoding(root, k))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment