Skip to content

Instantly share code, notes, and snippets.

@aryamccarthy
Created May 27, 2019 00:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save aryamccarthy/a678289681c4c3f167b0f671797e18be to your computer and use it in GitHub Desktop.
Save aryamccarthy/a678289681c4c3f167b0f671797e18be to your computer and use it in GitHub Desktop.
from dataclasses import astuple, dataclass
from typing import Callable, List, NewType, Set, Tuple
import numpy as np
# Seed NumPy's global RNG; DummyModel draws from it, so runs are repeatable.
np.random.seed(1337)

# Dummy typings; feel free to switch to PyTorch.
Word = str
# A phrase is a variable-length tuple of words (a single token or a
# multi-word noun phrase). ``Tuple[Word]`` would mean *exactly one* word.
Phrase = NewType('Phrase', Tuple[Word, ...])
Score = float  # or th.float, etc.
State = np.ndarray  # or th.Tensor, etc.
@dataclass(frozen=True, eq=False)
class Hypothesis:
    """Class for a node in the beam search beam.

    ``eq=False`` makes ``==`` and ``hash`` use object identity. The
    dataclass-generated versions would be broken for these field types:
    ``__hash__`` would try to hash the ``story`` list and ``candidates`` set
    (``TypeError``), and ``__eq__`` would compare ``state`` arrays
    element-wise, whose truth value is ambiguous (``ValueError``).
    """

    # Phrases placed so far, in order.
    story: List[Phrase]
    # Phrases not yet placed.
    candidates: Set[Phrase]
    # Score accumulated while building ``story``.
    score: Score
    # Model state after the last placed word.
    state: State
class Model:
    """Interface a scoring model implements to drive ``beam_search``.

    The base class supplies no behavior: each method below raises
    ``NotImplementedError`` and must be overridden by a subclass.
    """

    def __init__(self):
        """Create the model; the base class takes no configuration."""

    def initial_state(self) -> State:
        """Return the state the search starts from (empty story)."""
        raise NotImplementedError()

    def score(self, next_word: Word, state: State) -> Score:
        """Return the score of emitting ``next_word`` from ``state``."""
        raise NotImplementedError()

    def transition(self, next_word: Word, state: State) -> State:
        """Return the state reached after consuming ``next_word``."""
        raise NotImplementedError()

    def score_sequence(self, phrases: List[Phrase]) -> Score:
        """Return the score of an entire sequence of phrases."""
        raise NotImplementedError()
class DummyModel(Model):
    """Random stand-in for a real model, for exercising the search.

    Every draw comes from NumPy's global RNG, so output depends on the
    current ``np.random`` seed.
    """

    def initial_state(self) -> State:
        """Return a new vector of 4 standard-normal samples."""
        start = np.random.randn(4)
        return start

    def transition(self, next_word: Word, state: State) -> State:
        """Identity transition: hand back ``state`` unchanged."""
        return state

    def score(self, next_word: Word, state: State) -> Score:
        """Ignore both inputs; return a uniform draw from [0, 1)."""
        draw = np.random.random()
        return draw

    def score_sequence(self, phrases: List[Phrase]) -> Score:
        """Ignore ``phrases``; return a uniform draw from [0, 1)."""
        draw = np.random.random()
        return draw
def dummy_heuristic(remainder_to_process: Set[Phrase]) -> Score:
    """Return a future-score estimate of 0; exists only to show the API.

    A heuristic receives the phrases still to be placed and returns a Score
    that ``beam_search`` adds to a hypothesis's sequence score when ranking
    the beam. This one ignores its input entirely.

    By comparison, Schmaltz et al. use "a very simple unigram future cost
    estimate",

        g(R) = Σ_{i∈R} Σ_{w∈x_i} log p(w),

    i.e. the unigram log-probabilities of every word w in every remaining
    phrase x_i, summed.
    """
    return Score(0.0)
def beam_search(
    phrases: List[Phrase],  # Noun phrases or phrases containing one token
    K: int,  # Beam size
    g: Callable[[Set[Phrase]], Score],
    model: Model,
) -> List[List[Hypothesis]]:
    """Order ``phrases`` into a story by beam search.

    Beams are indexed by the number of *tokens* placed so far, so placing a
    multi-word phrase advances a hypothesis by ``len(phrase)`` beams at once.
    Each beam is ranked by ``model.score_sequence(story) + g(remaining)`` and
    cut down to the ``K`` highest-scoring hypotheses. This happens once, when
    the beam is complete (before it is expanded, and for the final beam).

    Args:
        phrases: The phrases to order. A set is expected, but any iterable of
            distinct phrases is accepted; it is copied, never mutated.
        K: Beam size, i.e. the number of hypotheses kept per beam.
        g: Heuristic estimate for the phrases that remain to be placed.
        model: Supplies the initial state, per-word scores and transitions,
            and the sequence score used for ranking.

    Returns:
        A list ``beams`` of length ``total_tokens + 1``. ``beams[n]`` holds
        the surviving hypotheses that have placed exactly ``n`` tokens, best
        first, and ``beams[-1]`` holds the complete stories.

    Raises:
        ValueError: If a phrase is empty, or if the same phrase is given
            twice. Remaining phrases are tracked as a set, so a repeat would
            otherwise be silently dropped.
    """
    phrase_list = list(phrases)
    remaining = set(phrase_list)
    if len(remaining) != len(phrase_list):
        raise ValueError("phrases must be distinct; duplicates would be lost")
    if any(len(phrase) == 0 for phrase in remaining):
        raise ValueError("every phrase must contain at least one token")
    # Number of phrases may not be number of tokens; beams are per token.
    total_tokens = sum(len(phrase) for phrase in remaining)

    def top_k(hypotheses: List[Hypothesis]) -> List[Hypothesis]:
        # Scores are meant as log-probabilities (the Schmaltz et al.
        # heuristic sums log p(w)), so higher is better. An ascending sort
        # followed by [:K] would keep the K *worst* hypotheses.
        return sorted(
            hypotheses,
            key=lambda hyp: model.score_sequence(hyp.story) + g(hyp.candidates),
            reverse=True,
        )[:K]

    beams: List[List[Hypothesis]] = [[] for _ in range(total_tokens + 1)]
    beams[0] = [Hypothesis([], remaining, 0.0, model.initial_state())]
    for m in range(total_tokens + 1):  # for all lengths (in tokens):
        # Every hypothesis landing in beams[m] comes from a shorter beam, so
        # beams[m] is complete now and needs pruning exactly once.
        if m > 0:
            beams[m] = top_k(beams[m])
        for hypothesis in beams[m]:
            # Read fields directly: dataclasses.astuple would deep-copy the
            # candidate set and the model state on every expansion.
            for phrase in hypothesis.candidates:
                new_score, new_state = hypothesis.score, hypothesis.state
                for word in phrase:
                    new_score += model.score(word, new_state)
                    new_state = model.transition(word, new_state)
                # len(phrase) <= tokens still unplaced, so the index is valid.
                beams[m + len(phrase)].append(
                    Hypothesis(
                        hypothesis.story + [phrase],
                        hypothesis.candidates - {phrase},
                        new_score,
                        new_state,
                    )
                )
    return beams
if __name__ == '__main__':
    import random
    from pprint import pprint  # Pretty-print

    # Shuffle the demo sentence so the search has something to reorder.
    random.seed(1337)
    model = DummyModel()
    tokens = "Papa ate the caviar with a spoon .".split()
    random.shuffle(tokens)

    # Each token becomes its own single-word phrase.
    phrases = {Phrase((Word(token),)) for token in tokens}

    # dummy_heuristic always returns 0.0, i.e. no future-score estimate.
    pprint(beam_search(phrases, 3, dummy_heuristic, model))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment