-
-
Save nebuta/5f113b259e2041f6f1c1a2b63b7b0a56 to your computer and use it in GitHub Desktop.
Wordle solver combining an SMT solver and entropy maximization
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re
from collections import Counter

import numpy as np
from z3 import *
# Let z3 use up to 8 threads.
set_param("parallel.enable", True)
set_param("parallel.threads.max", 8)

# Set to a word to skip computing the first guess with find_initial_guess().
initial_guess = None

# Hint marks (green, yellow and black tiles respectively).
RIGHT = "o"
MISPLACED = "-"
WRONG = "x"


def _load_words(path):
    """Return the whitespace-stripped, non-blank lines of the text file *path*.

    Blank lines are dropped: an empty entry would otherwise reach the solver's
    word-list constraint and match any letter combination.
    """
    with open(path, encoding="utf-8") as f:
        return [word for word in (line.strip() for line in f) if word]


# Words that may be the hidden answer (used to build the solver constraints).
all_answers_words5 = _load_words("wordlist_hidden.txt")
# Every word accepted as a guess (used by validate_guess and entropy search).
all_possible_words5 = _load_words("wordlist_all.txt")
# Convert the SMT solver's letter variables back into a word.
def to_s(m, vals):
    """Decode the values assigned to *vals* in model *m* into a string.

    Each variable holds a letter offset from "a" (0 -> "a", 1 -> "b", ...);
    the letters are concatenated in the order of *vals*.
    """
    base = ord("a")
    letters = []
    for var in vals:
        letters.append(chr(base + m[var].as_long()))
    return "".join(letters)
# Create the five SMT solver variables, one per letter position (1st..5th).
def make_vals(s):
    """Create one bit-vector variable per letter position and bound it on *s*.

    Each variable holds a letter as an offset from "a", so its valid range is
    0..25. The bound is added with ``ULE`` (unsigned <=). z3's ``<=`` operator
    on bit-vectors is a *signed* comparison, so ``v <= 25`` would also admit
    the bit patterns 32..63, which are negative when read as signed.

    The same signedness is why a 5-bit width appeared to give wrong results:
    25 does not fit in a signed 5-bit value and wraps to -7. With ``ULE``,
    5 bits would suffice, but the 6-bit width is kept so the variables
    themselves are unchanged.

    Args:
        s: z3 solver that receives the range constraints.

    Returns:
        List of five 6-bit ``BitVec`` variables named ``c0``..``c4``.
    """
    vals = [BitVec("c%d" % i, 6) for i in range(5)]
    for v in vals:
        s.add(ULE(v, 25))
    return vals
# Constrain the solver so the letter variables must spell one of the given words.
def only_listed_words(s, vals, words):
    """Add a constraint requiring *vals* to spell one of *words*.

    Each word becomes a conjunction of per-position equalities, and the whole
    list becomes a single disjunction of those conjunctions.

    Words whose length differs from ``len(vals)`` are skipped. This includes
    empty strings left by blank lines in a word list. Without the check,
    ``zip`` would constrain only a prefix of the positions. An empty word
    would produce an empty ``And``, which z3 treats as true, letting any
    letter combination through.

    Args:
        s: z3 solver to add the constraint to.
        vals: letter variables as created by ``make_vals``.
        words: iterable of words; assumed lowercase a-z (letters are encoded
            as offsets from "a").
    """
    width = len(vals)
    conds = [
        And([v == ord(c) - ord("a") for c, v in zip(w, vals)])
        for w in words
        if len(w) == width
    ]
    s.add(Or(conds))
# Build a solver restricted to the answer list and to every hint received so far.
def create_solver(all_answers, history):
    """Return ``(solver, vals)`` whose solutions are the still-possible answers.

    Args:
        all_answers: words that may be the hidden answer.
        history: sequence of ``[guess, hint]`` entries already played.

    Returns:
        The z3 solver and its five letter variables.
    """
    solver = Solver()
    letter_vars = make_vals(solver)
    only_listed_words(solver, letter_vars, all_answers)
    for entry in history:
        guess, hint = entry[0], entry[1]
        set_response(solver, letter_vars, guess, hint)
    return solver, letter_vars
# Run the solver repeatedly to enumerate every word that is still possible.
def get_candidates(s, v):
    """Enumerate every satisfying word of solver *s*, sorted alphabetically.

    After each model is found, a blocking clause excluding that exact
    assignment of the first five variables is added to *s*. The solver is
    therefore modified in place and ends up with one such clause per
    returned word.

    Args:
        s: z3 solver, e.g. from ``create_solver``.
        v: its letter variables.

    Returns:
        Sorted list of candidate words; empty if the constraints are
        unsatisfiable (or the solver does not report ``sat``).
    """
    found = []
    while s.check() == sat:
        model = s.model()
        found.append(to_s(model, v))
        # Forbid this exact assignment so the next check yields a new word.
        s.add(Or([v[i] != model[v[i]] for i in range(5)]))
    return sorted(found)
# Add the constraints implied by the hint received for one guess.
def set_response(solver, vals, cs, rs):
    """Encode the hint *rs* received for guess *cs* as constraints on *solver*.

    For each position ``i`` (up to ``len(vals)``) with guessed letter ``c``:

    * RIGHT: the answer has ``c`` at position ``i``.
    * MISPLACED: the answer does not have ``c`` at ``i``, and contains at least
      as many copies of ``c`` as the guess has non-WRONG copies of ``c``. For
      example, two MISPLACED "e"s, or one RIGHT and one MISPLACED "e", mean
      the answer has at least two "e"s.
    * WRONG: the answer does not have ``c`` at ``i``, and contains exactly as
      many copies of ``c`` as the guess has non-WRONG copies. This is zero
      when every ``c`` in the guess is WRONG.

    Positions with any other mark add no constraint.

    Args:
        solver: z3 solver to add constraints to.
        vals: letter variables from ``make_vals``.
        cs: the guessed word.
        rs: the hint string, one mark per letter of *cs*.
    """
    for i, c, r in zip(range(len(vals)), cs, rs):
        n = ord(c) - ord("a")
        if r == RIGHT:
            solver.add(vals[i] == n)
        elif r in (WRONG, MISPLACED):
            solver.add(vals[i] != n)
            # Copies of c the hint proves are present (marked RIGHT or MISPLACED).
            known = sum(1 for gc, gr in zip(cs, rs) if gc == c and gr != WRONG)
            occurrences = Sum([If(v == n, 1, 0) for v in vals])
            if r == WRONG:
                solver.add(occurrences == known)
            else:
                # known >= 1 here, so this also implies "c occurs somewhere".
                solver.add(occurrences >= known)
# Compute the hint that an answer would produce for a fixed guess.
def calculate_hint(guess):
    """Return a function mapping an answer to the hint shown for *guess*.

    The returned ``func(answer)`` applies the duplicate-letter rules in two
    passes:

    1. Positions where answer and guess agree are marked RIGHT, and each
       consumes one copy of that letter from the answer.
    2. The remaining positions, left to right, are marked MISPLACED while an
       unconsumed copy of the guessed letter is left (consuming it), and
       WRONG otherwise.

    Args:
        guess: the guessed word.

    Returns:
        ``func(answer) -> str`` giving ``len(guess)`` marks. A position beyond
        the end of *answer* keeps the placeholder "*".
    """

    def func(answer):
        marks = ["*"] * len(guess)
        remaining = Counter(answer)
        pairs = list(zip(answer, guess))
        for i, (a, g) in enumerate(pairs):
            if a == g:
                marks[i] = RIGHT
                remaining[g] -= 1
        for i, (a, g) in enumerate(pairs):
            if a == g:
                continue
            # Counter returns 0 for absent letters, covering "not in answer".
            if remaining[g] > 0:
                marks[i] = MISPLACED
                remaining[g] -= 1
            else:
                marks[i] = WRONG
        return "".join(marks)

    return func
# Tally how many candidate answers would produce each hint for a guess.
def get_counts(candidates, guess):
    """Group *candidates* by the hint each would give for *guess*.

    Returns:
        Dict mapping hint string -> number of candidates producing it, in
        order of first appearance.
    """
    hint_of = calculate_hint(guess)
    tally = {}
    for hint in map(hint_of, candidates):
        tally[hint] = tally.get(hint, 0) + 1
    return tally
# Compute the entropy of a hint tally produced by get_counts().
def calculate_entropy(counts):
    """Return the Shannon entropy, in bits, of the distribution in *counts*.

    The values of *counts* are normalised to probabilities and
    ``sum(-p * log2(p))`` is returned; an empty dict gives 0.0.
    """
    frequencies = np.fromiter(counts.values(), dtype=float)
    p = frequencies / frequencies.sum()
    return (-p * np.log2(p)).sum()
# Score each guess in candidates_for_guessing by the entropy of the hints it
# would produce over the possible answers in candidates_for_answer.
def get_entropies_for_guesses(candidates_for_guessing, candidates_for_answer):
    """Rank guesses by how evenly they split the possible answers.

    Returns:
        List of ``[guess, entropy, counts]`` rows sorted by entropy, highest
        first. Guesses with equal entropy keep their input order.
    """

    def score(guess):
        counts = get_counts(candidates_for_answer, guess)
        return [guess, calculate_entropy(counts), counts]

    rows = [score(g) for g in candidates_for_guessing]
    # reverse=True keeps the sort stable, matching an ascending sort on -entropy.
    return sorted(rows, key=lambda row: row[1], reverse=True)
def find_initial_guess(verbose=False):
    """Return the allowed word whose hints over all answers have maximal entropy.

    Every word in ``all_possible_words5`` is scored against every word in
    ``all_answers_words5``.

    Args:
        verbose: if true, print the top five ``[guess, entropy, counts]`` rows.
    """
    print("Finding initial guess...")
    ranked = get_entropies_for_guesses(all_possible_words5, all_answers_words5)
    if verbose:
        for row in ranked[:5]:
            print(row)
    best_row = ranked[0]
    return best_row[0]
def validate_guess(s):
    """Return True when *s* is one of the allowed guesses in ``all_possible_words5``."""
    allowed_words = all_possible_words5
    is_allowed = s in allowed_words
    return is_allowed
def validate_hint(s):
    """Return True if *s* consists of exactly five RIGHT/MISPLACED/WRONG marks.

    ``re.fullmatch`` is required here. ``re.match`` anchors only at the start,
    so it would accept inputs such as ``"ooooox"`` or ``"ooooo "``. The
    latter would then fail the all-RIGHT comparison in ``main``.
    """
    pat = rf"[{re.escape(RIGHT)}{re.escape(MISPLACED)}{re.escape(WRONG)}]{{5}}"
    return re.fullmatch(pat, s) is not None
def input_actual_try(suggested):
    """Prompt for the word actually played; blank input selects *suggested*.

    Input is stripped and lower-cased before validation, so ``" Crane"`` is
    accepted as ``"crane"``. The solver encodes letters as offsets from "a",
    so the word lists are presumably lowercase.

    Returns:
        A word accepted by ``validate_guess``.
    """
    print("Suggested try: ", suggested)
    while True:
        guess = input("Enter actual try (blank for suggested): ").strip().lower()
        if not guess:
            guess = suggested
        if validate_guess(guess):
            return guess
        print("Invalid try.")
def input_hint():
    """Prompt until a valid five-mark hint is entered, and return it.

    Surrounding whitespace is stripped before validation. Otherwise an
    accidental trailing space could slip past a start-anchored check and
    break the comparison against an all-RIGHT hint.
    """
    prompt = f"Enter hint in five letters ({RIGHT}:green, {MISPLACED}:yellow, {WRONG}:black): "
    while True:
        hint = input(prompt).strip()
        if validate_hint(hint):
            return hint
        print("Invalid hint format.")
# Find the guess whose resulting hint has maximal entropy.
def get_recommended_try(response_history, verbose=False):
    """Suggest the next guess given all ``[guess, hint]`` pairs played so far.

    Guesses are drawn from both the remaining candidate answers and the full
    allowed-word list, and the one with the highest hint entropy wins. When
    the two best entropies are within 0.001 bits, the candidate answer is
    preferred, since it might win outright.

    Args:
        response_history: list of ``[guess, hint]`` entries.
        verbose: if true, print the top three rows of each ranking.

    Returns:
        The suggested word. Returns None when no further guess is needed:
        either exactly one candidate remains (it is printed), or none remain
        because the entered hints are inconsistent with the answer list.
    """
    solver, vals = create_solver(all_answers_words5, response_history)
    candidates = get_candidates(solver, vals)
    print(
        "Total %d candidates" % len(candidates)
        + (": " + ", ".join(candidates) if len(candidates) <= 100 else "")
    )
    if not candidates:
        # Previously this fell through and crashed on entropies1[0] (IndexError).
        print("No candidates remain; please check the hints entered so far.")
        return None
    if len(candidates) == 1:
        print("Found answer: ", candidates[0])
        return None

    entropies1 = get_entropies_for_guesses(candidates, candidates)
    if verbose:
        _print_top_entropies("Entropies top 3 from possible answer words", entropies1)
    entropies2 = get_entropies_for_guesses(all_possible_words5, candidates)
    if verbose:
        _print_top_entropies("Entropies top 3 from all words", entropies2)

    # Pick the overall best, but favour a possible answer when the entropies
    # are nearly equal.
    if entropies2[0][1] > entropies1[0][1] + 0.001:
        return entropies2[0][0]
    return entropies1[0][0]


def _print_top_entropies(title, entropies, k=3):
    """Print a blank line, *title*, then the first *k* ``[guess, entropy, counts]`` rows."""
    print()
    print(title)
    for row in entropies[:k]:
        print(row)
def main():
    """Run the interactive loop: suggest a guess, read the hint, repeat.

    The loop ends when the user reports an all-RIGHT hint, or when
    ``get_recommended_try`` returns None (answer determined or no candidates).
    """
    global initial_guess
    history = []
    if initial_guess is None:
        # Computed once and cached in the module-level variable.
        initial_guess = find_initial_guess()
    suggestion = initial_guess
    round_no = 0
    while suggestion is not None:
        round_no += 1
        print("Round %d:" % round_no)
        guess = input_actual_try(suggestion)
        hint = input_hint()
        if hint == RIGHT * 5:
            print("Answer found.")
            break
        history.append([guess, hint])
        suggestion = get_recommended_try(history)
    print("Done.")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment