Skip to content

Instantly share code, notes, and snippets.

@nebuta
Created February 15, 2022 15:30
Show Gist options
  • Save nebuta/5f113b259e2041f6f1c1a2b63b7b0a56 to your computer and use it in GitHub Desktop.
Save nebuta/5f113b259e2041f6f1c1a2b63b7b0a56 to your computer and use it in GitHub Desktop.
Wordle solver combining an SMT solver and entropy maximization
# NOTE: the z3 star import is kept first so the explicit imports below take
# precedence over any name it might also export.
from z3 import *
import re
from collections import Counter

import numpy as np
set_param("parallel.enable", True)
set_param("parallel.threads.max", 8)

# Word used for round 1; main() computes it with find_initial_guess() when None.
initial_guess = None

# Hint symbols entered by the user for each letter of a try.
RIGHT = "o"  # green: right letter, right position
MISPLACED = "-"  # yellow: letter occurs elsewhere in the answer
WRONG = "x"  # black: no (further) copy of the letter in the answer


def _load_words(path):
    """Return the stripped, non-blank lines of ``path`` as a list of words.

    Blank lines (e.g. an extra empty line at the end of the file) are
    skipped. Otherwise they would enter the list as the empty string, which
    is not a valid word, and ``only_listed_words`` would build a degenerate,
    empty ``And`` constraint for it.
    """
    with open(path, encoding="utf-8") as f:
        return [word for word in (line.strip() for line in f) if word]


# Words that can be the hidden answer.
all_answers_words5 = _load_words("wordlist_hidden.txt")
# Words accepted as tries.
all_possible_words5 = _load_words("wordlist_all.txt")
def to_s(m, vals):
    """Decode the letter variables ``vals`` in model ``m`` into a word.

    Each variable holds a letter index (0 for "a", 1 for "b", ...); the
    model value is read with ``as_long()``.
    """
    letters = []
    for var in vals:
        letters.append(chr(ord("a") + m[var].as_long()))
    return "".join(letters)
def make_vals(s):
    """Create the five letter variables ``c0``..``c4`` and bound them to 0..25.

    Each variable is a 6-bit z3 bit-vector holding a letter index
    (``ord(letter) - ord("a")``). The range constraint is added to ``s``.

    Returns:
        List of the five bit-vector variables, one per position.
    """
    vals = [BitVec("c%d" % i, 6) for i in range(5)]
    for v in vals:
        # `<=` on z3 bit-vectors is a *signed* comparison: `v <= 25` also
        # admits 32..63 (negative when read as signed), and with 5 bits the
        # constant 25 itself wraps to a negative number, which presumably
        # caused the odd results originally seen with 5 bits. ULE compares
        # unsigned, so this bound really restricts v to 0..25.
        s.add(ULE(v, 25))
    return vals
def only_listed_words(s, vals, words):
    """Constrain the letter variables ``vals`` to spell one of ``words``.

    Adds to ``s`` a single disjunction with one conjunct per word; each
    conjunct fixes every position to that word's letter index.
    """
    base = ord("a")
    spellings = [
        And([v == ord(ch) - base for ch, v in zip(word, vals)])
        for word in words
    ]
    s.add(Or(spellings))
def create_solver(all_answers, history):
    """Build a solver whose models are the answers consistent with ``history``.

    Args:
        all_answers: words the hidden answer may be.
        history: sequence of ``[guess, hint]`` pairs already played.

    Returns:
        ``(solver, vals)``: the z3 solver and its five letter variables.
    """
    solver = Solver()
    letters = make_vals(solver)
    only_listed_words(solver, letters, all_answers)
    for entry in history:
        guess, hint = entry[0], entry[1]
        set_response(solver, letters, guess, hint)
    return solver, letters
def get_candidates(s, v):
    """Enumerate every word satisfying the constraints held by solver ``s``.

    After each model is found, a blocking clause (at least one of the five
    letters must differ) is added to ``s``, so the next ``check`` yields a new
    word. These clauses stay in ``s`` afterwards.

    Returns:
        Sorted list of the words found. The list is empty when the first
        ``check`` is not ``sat``.
    """
    words = []
    while s.check() == sat:
        model = s.model()
        words.append(to_s(model, v))
        s.add(Or([v[i] != model[v[i]] for i in range(5)]))
    return sorted(words)
def set_response(solver, vals, cs, rs):
    """Add to ``solver`` the constraints implied by one try and its hint.

    Args:
        solver: z3 solver receiving the constraints.
        vals: the five letter variables created by ``make_vals``.
        cs: the word that was tried.
        rs: the hint, one RIGHT / MISPLACED / WRONG symbol per letter.

    Let ``k`` be the number of positions where a letter ``c`` of the try was
    marked RIGHT or MISPLACED. The answer then contains ``c`` exactly ``k``
    times if some occurrence of ``c`` was marked WRONG, and at least ``k``
    times otherwise.
    """

    def occurrences(n):
        # Number of positions of the answer holding letter index n.
        return Sum([If(vals[j] == n, 1, 0) for j in range(5)])

    for i, c, r in zip(range(5), cs, rs):
        n = ord(c) - ord("a")
        # Non-WRONG marks for this letter anywhere in the try.
        count = len([cr for cr in zip(cs, rs) if cr[0] == c and cr[1] != WRONG])
        if r == WRONG:
            solver.add(vals[i] != n)
            solver.add(occurrences(n) == count)
        elif r == RIGHT:
            solver.add(vals[i] == n)
        elif r == MISPLACED:
            solver.add(vals[i] != n)
            # "At least one copy" is too weak when the letter is repeated in
            # the try (two yellows, or a green plus a yellow): the answer must
            # contain at least as many copies as there are non-WRONG marks.
            solver.add(occurrences(n) >= count)
def calculate_hint(guess):
    """Return a function computing the hint ``guess`` gets against an answer.

    The returned ``func(answer)`` builds the hint string for ``guess`` if the
    hidden word were ``answer``, with one mark per position:

    * RIGHT for an exact match;
    * MISPLACED when the guessed letter still has unclaimed copies in the
      answer;
    * WRONG otherwise.

    Exact matches are resolved first. Remaining copies are then claimed left
    to right, so surplus copies of a repeated guessed letter get WRONG.
    """

    def func(answer):
        # "*" survives only at positions past the end of a short answer.
        marks = ["*", "*", "*", "*", "*"]
        # Copies of each answer letter not yet claimed by a RIGHT/MISPLACED mark.
        remaining = {}
        for letter in answer:
            remaining[letter] = remaining.get(letter, 0) + 1
        # Pass 1: exact matches take priority over misplaced letters.
        for idx, (a, g) in enumerate(zip(answer, guess)):
            if a == g:
                marks[idx] = RIGHT
                remaining[g] -= 1
        # Pass 2: the other positions claim any remaining copies, left to right.
        for idx, (a, g) in enumerate(zip(answer, guess)):
            if a != g:
                if remaining.get(g, 0) > 0:
                    marks[idx] = MISPLACED
                    remaining[g] -= 1
                else:
                    marks[idx] = WRONG
        return "".join(marks)

    return func
def get_counts(candidates, guess):
    """Tally how many ``candidates`` would produce each hint for ``guess``.

    Returns:
        dict mapping hint string -> number of candidate answers yielding it,
        in order of first appearance. The result is converted from Counter to
        a plain dict so that printed entropy tables look as before.
    """
    hint_for = calculate_hint(guess)
    # Counter iterates the candidates and counts the hints in C.
    return dict(Counter(map(hint_for, candidates)))
def calculate_entropy(counts):
    """Return the Shannon entropy, in bits, of the hint tally ``counts``.

    ``counts`` maps each hint to how many candidates produce it (as returned
    by ``get_counts``). The counts are normalized to probabilities ``p`` and
    ``sum(-p * log2(p))`` is returned. An empty mapping gives 0.0.
    """
    freq = np.asarray(list(counts.values()))
    p = freq / freq.sum()
    return np.sum(-p * np.log2(p))
def get_entropies_for_guesses(candidates_for_guessing, candidates_for_answer):
    """Score each word in ``candidates_for_guessing`` by hint entropy.

    The hints are computed against every possible answer in
    ``candidates_for_answer``.

    Returns:
        List of ``[guess, entropy, counts]`` rows sorted by decreasing entropy.
        The sort is stable, so ties keep the order of
        ``candidates_for_guessing``.
    """
    scored = []
    for guess in candidates_for_guessing:
        tally = get_counts(candidates_for_answer, guess)
        scored.append([guess, calculate_entropy(tally), tally])
    scored.sort(key=lambda row: -row[1])
    return scored
def find_initial_guess(verbose=False):
    """Return the accepted word with the highest hint entropy over all answers.

    Scores every word of ``all_possible_words5`` against every word of
    ``all_answers_words5``. When ``verbose`` is true, the top five
    ``[word, entropy, counts]`` rows are printed.
    """
    print("Finding initial guess...")
    ranking = get_entropies_for_guesses(all_possible_words5, all_answers_words5)
    if verbose:
        for row in ranking[:5]:
            print(row)
    best, *_ = ranking[0]
    return best
def validate_guess(s):
    """Return whether ``s`` is one of the words accepted as a try."""
    is_known_word = s in all_possible_words5
    return is_known_word
def validate_hint(s):
    """Return True if ``s`` consists of exactly five RIGHT/MISPLACED/WRONG symbols."""
    symbols = re.escape(RIGHT) + re.escape(MISPLACED) + re.escape(WRONG)
    # fullmatch, not match: re.match only anchors at the start, so it would
    # also accept longer input such as "oooooo". Such a hint never equals
    # RIGHT * 5 and would be stored in the history as-is.
    return re.fullmatch(rf"[{symbols}]{{5}}", s) is not None
def input_actual_try(suggested):
    """Prompt for the word actually tried and return it.

    Shows ``suggested`` first. An empty entry selects the suggestion. The
    prompt repeats until ``validate_guess`` accepts the word.
    """
    print("Suggested try: ", suggested)
    prompt = "Enter actual try (blank for suggested): "
    while True:
        entered = input(prompt)
        word = suggested if entered == "" else entered
        if validate_guess(word):
            return word
        print("Invalid try.")
def input_hint():
    """Prompt until the user enters a hint accepted by ``validate_hint``; return it."""
    prompt = f"Enter hint in five letters ({RIGHT}:green, {MISPLACED}:yellow, {WRONG}:black): "
    while True:
        entered = input(prompt)
        if validate_hint(entered):
            return entered
        print("Invalid hint format.")
def get_recommended_try(response_history, verbose=False):
    """Return the next try that maximizes the entropy of the expected hint.

    The answers consistent with ``response_history`` are enumerated with the
    SMT solver. Each candidate answer, and every accepted word, is then
    scored by the entropy of the hints it would produce over those
    candidates.

    Args:
        response_history: list of ``[guess, hint]`` pairs entered so far.
        verbose: when true, print the top three rows from each word pool.

    Returns:
        The suggested next try. Returns None when exactly one candidate is
        left (it is printed as the answer) or when no candidate matches the
        history, e.g. because a hint was mistyped.
    """
    solver, vals = create_solver(all_answers_words5, response_history)
    candidates = get_candidates(solver, vals)
    print(
        "Total %d candidates" % len(candidates)
        + (": " + ", ".join(candidates) if len(candidates) <= 100 else "")
    )
    if not candidates:
        # Without this guard entropies1 below would be empty and
        # entropies1[0] would raise IndexError.
        print("No candidate matches the entered tries and hints.")
        return None
    if len(candidates) == 1:
        print("Found answer: ", candidates[0])
        return None
    entropies1 = get_entropies_for_guesses(candidates, candidates)
    if verbose:
        print()
        print("Entropies top 3 from possible answer words")
        for e in entropies1[0:3]:
            print(e)
    entropies2 = get_entropies_for_guesses(all_possible_words5, candidates)
    if verbose:
        print()
        print("Entropies top 3 from all words")
        for e in entropies2[0:3]:
            print(e)
    # Pick the highest-entropy word among candidates and all_possible_words5.
    # If the entropies are nearly equal (within 0.001 bits), prefer a word from
    # candidates, since it might be the answer itself.
    suggested = (
        entropies2[0][0]
        if entropies2[0][1] > entropies1[0][1] + 0.001
        else entropies1[0][0]
    )
    return suggested
def main():
    """Run the interactive loop: suggest a try, read the hint, repeat.

    The first suggestion is ``initial_guess``, computed once with
    ``find_initial_guess`` if it is still None. Later suggestions come from
    ``get_recommended_try``. The loop stops when the user enters an
    all-RIGHT hint or when no further suggestion is returned.
    """
    global initial_guess
    if initial_guess is None:
        initial_guess = find_initial_guess()

    history = []
    suggestion = initial_guess
    round_no = 0
    while True:
        round_no += 1
        print("Round %d:" % round_no)
        attempt = input_actual_try(suggestion)
        feedback = input_hint()
        if feedback == RIGHT * 5:
            print("Answer found.")
            break
        history.append([attempt, feedback])
        suggestion = get_recommended_try(history)
        if suggestion is None:
            break
    print("Done.")


if __name__ == "__main__":
    # Start the interactive session only when run as a script.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment