@nariaki3551
Last active July 12, 2023 01:58
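from heapq import heappush, heappushpop
from itertools import count


def beam_search(root, k, api):
    """
    Beam search that keeps the k highest-scoring paths after each step.

    Args:
        root : root node (the initial path)
        k    : beam width, i.e. the number of paths kept during the search
        api  : object providing the callbacks used by the search

    Returns:
        (result_paths, result_paths_score): the surviving paths and their
        scores, sorted from best to worst.

    Notes:
        api must provide the following methods.
        (1) init      : called once at the beginning of the search
        (2) step      : return a list or generator of paths extended from the given path
        (3) score     : return the score of a path; higher scores are better
        (4) count     : called once at the end of every loop iteration
        (5) terminate : return True if the search should stop, else False
    """
    # Heap entries are (score, tie_breaker, path). The tie-breaker keeps
    # heapq from comparing the paths themselves when two scores are equal.
    tie_breaker = count()
    paths = [(None, next(tie_breaker), root)]
    api.init()
    while not api.terminate():
        # Min-heap of at most k entries: top_paths[0] always holds the lowest
        # score kept, so heappushpop evicts the weakest entry once the beam is full.
        top_paths = []
        for _, _, path in paths:
            for extend_path in api.step(path):
                entry = (api.score(extend_path), next(tie_breaker), extend_path)
                if len(top_paths) < k:
                    heappush(top_paths, entry)
                else:
                    heappushpop(top_paths, entry)
        paths = top_paths
        api.count()
    # Report each path with its own score (not the last score computed in
    # the loop), best first.
    paths.sort(key=lambda entry: entry[0], reverse=True)
    result_paths = [path for _, _, path in paths]
    result_paths_score = [score for score, _, _ in paths]
    return result_paths, result_paths_score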
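The docstring lists five callbacks that `api` must provide. Below is a minimal sketch of one such object, assuming a toy problem that is not part of the gist: paths are tuples of bits, each step appends a 0 or a 1, and a path scores one point per 1 bit. The class name `BitStringApi` and its `max_depth` parameter are hypothetical.

```python
class BitStringApi:
    """Hypothetical api object: grows bit strings one bit per step and
    scores a string by its number of 1 bits."""

    def __init__(self, max_depth):
        self.max_depth = max_depth  # number of extension steps to run
        self.depth = 0

    def init(self):
        # Called once at the start of the search: reset the step counter.
        self.depth = 0

    def step(self, path):
        # Extend a path (a tuple of bits) in both possible ways.
        return [path + (0,), path + (1,)]

    def score(self, path):
        # Higher is better: count the 1 bits.
        return sum(path)

    def count(self):
        # Called at the end of every loop iteration: advance the counter.
        self.depth += 1

    def terminate(self):
        # Stop after max_depth extension steps.
        return self.depth >= self.max_depth


api = BitStringApi(max_depth=2)
api.init()
print(api.step((1,)))         # [(1, 0), (1, 1)]
print(api.score((1, 0, 1)))   # 2
```

Passed to the function above as `beam_search((), 3, BitStringApi(max_depth=4))`, this runs four extension steps with a beam of three paths; because the all-ones prefix always has the top score, `(1, 1, 1, 1)` should be among the returned paths.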
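The beam update in `beam_search` relies on a standard `heapq` idiom: a min-heap capped at k entries retains the k largest scores, because `heappushpop` pushes the new entry and then pops the smallest one. The standalone sketch below isolates that idiom; the function name `top_k` and its `key` parameter are illustrative only, and the counter in each tuple is there so that equal scores never force a comparison between the items themselves.

```python
from heapq import heappush, heappushpop
from itertools import count


def top_k(items, k, key):
    """Return the k items with the largest key, best first (illustrative)."""
    heap = []       # min-heap of (score, tie, item), never more than k entries
    tie = count()   # unique tie-breaker so items are never compared directly
    for item in items:
        entry = (key(item), next(tie), item)
        if len(heap) < k:
            heappush(heap, entry)
        else:
            # Push, then drop the smallest: heap[0] is the weakest entry kept.
            heappushpop(heap, entry)
    heap.sort(reverse=True)  # best first
    return [item for _, _, item in heap]


words = ["beam", "a", "search", "heap", "of", "paths"]
print(top_k(words, 3, key=len))  # ['search', 'paths', 'heap']
```

For a single iterable, `heapq.nlargest(k, items, key=key)` gives the same selection up to how ties are broken; the explicit loop mirrors how `beam_search` interleaves scoring with its nested iteration over paths and their extensions.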