Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
from dataclasses import astuple, dataclass
from typing import Callable, List, NewType, Set, Tuple
import numpy as np
# Dummy typings; feel free to switch to PyTorch.
Word = str  # A single token.
# A phrase is an ordered tuple of one or more words (the search iterates over
# its words and takes its length), so it is typed as a variable-length tuple
# rather than the 1-tuple that ``Tuple[Word]`` would denote.
Phrase = NewType('Phrase', Tuple[Word, ...])
Score = float  # or th.float, etc.
State = np.ndarray  # or th.Tensor, etc.
@dataclass
class Hypothesis:
    """Class for a node in the beam search beam.

    Declared as a dataclass so that it can be constructed positionally as
    ``Hypothesis(story, candidates, score, state)`` and handled by the
    ``dataclasses`` helpers such as ``astuple``.
    """

    # Phrases placed so far, in output order.
    story: List[Phrase]
    # Phrases that have not been placed yet.
    candidates: Set[Phrase]
    # Score accumulated word by word while building ``story``.
    score: Score
    # Model state after consuming every word of ``story``.
    state: State
class Model:
    """Interface of the incremental model that drives the beam search.

    Every method except ``__init__`` raises ``NotImplementedError`` here;
    subclasses are expected to override them.
    """

    def __init__(self):
        """Nothing to set up in the base class."""

    def initial_state(self) -> State:
        """Return the state a hypothesis starts from, before any word."""
        raise NotImplementedError

    def score(self, next_word: Word, state: State) -> Score:
        """Return the score of appending ``next_word`` in ``state``."""
        raise NotImplementedError

    def transition(self, next_word: Word, state: State) -> State:
        """Return the state reached from ``state`` by consuming ``next_word``."""
        raise NotImplementedError

    def score_sequence(self, phrases: List[Phrase]) -> Score:
        """Return the score of a whole sequence of phrases."""
        raise NotImplementedError
class DummyModel(Model):
    """Stand-in model that answers with random numbers.

    Useful only for exercising the search API; none of the methods look at
    the words they are given.
    """

    def initial_state(self) -> State:
        """Return a freshly sampled random vector with four entries."""
        start = np.random.randn(4)
        return start

    def score(self, next_word: Word, state: State) -> Score:
        """Return a random number; both arguments are ignored."""
        random_score = np.random.random()
        return random_score

    def transition(self, next_word: Word, state: State) -> State:
        """Return ``state`` itself, unchanged; ``next_word`` is ignored."""
        return state

    def score_sequence(self, phrases: List[Phrase]) -> Score:
        """Return a random number; ``phrases`` is ignored."""
        random_score = np.random.random()
        return random_score
def dummy_heuristic(remainder_to_process: Set[Phrase]) -> Score:
    """Estimate the future score of the phrases still to be placed.

    Just return 0, to show the API.

    By comparison, Schmaltz et al. use
    'a very simple unigram future cost estimate,
    g(R) = sum[i in R] sum[w in x_i] log p(w).'

    Args:
        remainder_to_process: Phrases not yet placed; ignored here.

    Returns:
        Always ``Score(0.0)``.
    """
    return Score(0.0)
def beam_search(
    phrases: List[Phrase],  # Noun phrases or phrases containing one token
    K: int,  # Beam size
    g: Callable[[Set[Phrase]], Score],
    model: Model,
) -> List[List[Hypothesis]]:
    """Order ``phrases`` with a beam search driven by ``model``.

    Beams are indexed by the number of *words* consumed so far, so a
    hypothesis that places a phrase of ``n`` words moves from beam ``m`` to
    beam ``m + n``.  The last beam therefore holds the complete orderings.

    Args:
        phrases: Phrases to order.  Any iterable is accepted; it is copied
            into a set, so duplicate phrases collapse into one.
        K: Beam size, i.e. the maximum number of hypotheses kept per beam.
        g: Future-score estimate for a set of not-yet-placed phrases.
        model: Provides the initial state, per-word scores and transitions,
            and the sequence score used for ranking.

    Returns:
        A list of ``M + 1`` beams, where ``M`` is the total number of words
        in the (deduplicated) phrases.  Each beam holds at most ``K``
        hypotheses, sorted by ascending ranking key.

    Raises:
        ValueError: If ``K`` is smaller than 1 or a phrase contains no words.
    """
    if K < 1:
        raise ValueError("K (beam size) must be at least 1")

    # Hypothesis.candidates relies on set difference, so work on a set even
    # when the caller hands in a list.
    pending: Set[Phrase] = set(phrases)
    if any(len(phrase) == 0 for phrase in pending):
        # An empty phrase would never advance the word count.
        raise ValueError("phrases must each contain at least one word")

    # Number of phrases may not be number of tokens: beams are indexed by
    # tokens, so size the table by the total word count.
    M = sum(len(phrase) for phrase in pending)
    beams: List[List[Hypothesis]] = [[] for _ in range(M + 1)]
    beams[0] = [Hypothesis([], pending, 0.0, model.initial_state())]

    def rank(hyp: Hypothesis) -> Score:
        """Ranking key: score of the story so far plus the future estimate."""
        return model.score_sequence(hyp.story) + g(hyp.candidates)

    for m in range(M + 1):  # for all lengths:
        # Every hypothesis of length m comes from a shorter one, so beam m
        # is complete by now and can be pruned once, before it is expanded.
        # NOTE(review): ascending order keeps the K *lowest* keys, exactly
        # as the original did.  If scores are log-probabilities (higher is
        # better) this should sort with reverse=True -- confirm.
        if len(beams[m]) > 1:
            beams[m] = sorted(beams[m], key=rank)[:K]

        for hypothesis in beams[m]:  # for each hypothesis at this point:
            for phrase in hypothesis.candidates:
                # States are passed along without copying, so
                # model.transition is expected to return the next state
                # rather than mutate its argument in place.
                new_score, new_state = hypothesis.score, hypothesis.state
                for word in phrase:
                    new_score += model.score(word, new_state)
                    new_state = model.transition(word, new_state)
                beams[m + len(phrase)].append(
                    Hypothesis(
                        hypothesis.story + [phrase],
                        hypothesis.candidates - {phrase},
                        new_score,
                        new_state,
                    )
                )
    return beams
if __name__ == '__main__':
    from pprint import pprint  # Pretty-print

    # Demo: treat every token of a toy sentence as a one-word phrase and
    # search over orderings with the random DummyModel and a beam of 3.
    model = DummyModel()
    raw_tokens = "Papa ate the caviar with a spoon .".split()
    as_words = [Word(w) for w in raw_tokens]
    as_phrases = [Phrase((w, )) for w in as_words]
    phrases = set(as_phrases)
    # dummy_heuristic always returns 0, i.e. no future-score estimate.
    pprint(beam_search(phrases, 3, dummy_heuristic, model))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.