Last active
July 12, 2023 01:58
-
-
Save nariaki3551/a23fb85b481396022299b8a204259d47 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from heapq import heapify, heappush, heappop, heappushpop
from itertools import count
def beam_search(root, k, api):
    """Run a beam search from ``root``, keeping the ``k`` best paths per step.

    Args:
        root : root node (the initial path)
        k : beam width, i.e. number of paths kept after each step
        api : object providing the callbacks used by the search

    Notes:
        api must provide the following methods.
        (1) init : called once at the beginning of this function
        (2) step : return a list or generator of paths extended from the
            given path
        (3) score : return the score of a path; higher scores are better
        (4) count : called at the end of every loop iteration
        (5) terminate : return True if the search should stop, else False

    Returns:
        (result_paths, result_paths_score): the surviving paths and their
        matching scores, ordered best-first. If ``api.terminate()`` is already
        true before the first step, this is ``([root], [None])`` because the
        root is never scored.
    """
    tie_breaker = count()
    # Entries are (score, tiebreak, path). The strictly increasing tiebreak
    # keeps heapq from ever comparing two paths when their scores are equal
    # (paths need not be orderable).
    paths = [(None, next(tie_breaker), root)]
    api.init()
    while not api.terminate():
        top_paths = []  # min-heap: top_paths[0] is the worst path kept so far
        for _, _, path in paths:
            for extend_path in api.step(path):
                entry = (api.score(extend_path), next(tie_breaker), extend_path)
                if len(top_paths) < k:
                    heappush(top_paths, entry)
                else:
                    # Push, then drop the smallest, so only the k best survive.
                    heappushpop(top_paths, entry)
        paths = top_paths
        api.count()
    # Best-first; ties on score are resolved by the tiebreak, never by path.
    paths.sort(reverse=True)
    result_paths = [path for _, _, path in paths]
    result_paths_score = [score for score, _, _ in paths]
    return result_paths, result_paths_score
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment