Created
January 31, 2022 17:44
-
-
Save maxastyler/8c026cd88cf062acc0e5738b31a73353 to your computer and use it in GitHub Desktop.
Code from wordle blog post
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from dataclasses import dataclass, field | |
from collections import defaultdict | |
from typing import Literal, Tuple, Callable, TypeVar, Optional | |
K = TypeVar("K")
V = TypeVar("V")


def merge(A: dict[K, V], B: dict[K, V], f: Callable[[V, V], V]) -> dict[K, V]:
    """Combine two dictionaries into a new one, resolving clashes with `f`.

    A key found in only one input keeps its value unchanged.  A key found
    in both is stored as `f(A[key], B[key])`.  Neither input is modified.
    """
    exclusive_keys = A.keys() ^ B.keys()
    shared_keys = A.keys() & B.keys()

    combined: dict[K, V] = {}
    for key in exclusive_keys:
        # An exclusive key lives in exactly one of the two inputs.
        combined[key] = A[key] if key in A else B[key]
    for key in shared_keys:
        combined[key] = f(A[key], B[key])
    return combined
@dataclass
class Unplaced:
    """What is known about a letter that is in the answer but not yet located.

    `number` counts this letter in the answer: it is exact when `complete`
    is true and only a lower bound otherwise.  `positions` holds the
    indices the letter is known *not* to occupy.
    """

    complete: bool
    number: int
    positions: set[int]

    def __add__(self, other: "Unplaced") -> "Unplaced":
        """Merge two observations of the same letter into a new `Unplaced`."""
        is_complete = self.complete or other.complete
        count = max(self.number, other.number)
        excluded = self.positions | other.positions
        return Unplaced(complete=is_complete, number=count, positions=excluded)
@dataclass
class Knowledge:
    """Everything learned so far about the hidden answer.

    `not_contained` holds letters known to be absent from the answer.
    `placed` holds `(index, letter)` pairs whose position is confirmed.
    `unplaced` maps a letter to an `Unplaced` record describing how many of
    that letter the answer holds and which positions it cannot occupy.
    """

    not_contained: set[str] = field(default_factory=set)
    placed: set[tuple[int, str]] = field(default_factory=set)
    unplaced: dict[str, "Unplaced"] = field(default_factory=dict)

    def __add__(self, other: "Knowledge") -> "Knowledge":
        """Return the union of two bodies of knowledge as a new object."""
        return Knowledge(
            not_contained=self.not_contained | other.not_contained,
            placed=self.placed | other.placed,
            # A letter may be unplaced on both sides; `+` on the two records
            # combines them into one.
            unplaced=merge(self.unplaced, other.unplaced, lambda a, b: a + b),
        )

    def complete(self, word_length: int = 5) -> bool:
        """Return True when every position of the word has been placed.

        `word_length` is the number of letters in the answer; it defaults
        to 5 so existing zero-argument calls behave as before.
        """
        return len(self.placed) == word_length

    def contains(self, word: str) -> bool:
        """Return True if `word` is consistent with this knowledge."""
        if any(letter in word for letter in self.not_contained):
            return False

        for index, letter in self.placed:
            # A word too short to have this position cannot match it.
            if index >= len(word) or word[index] != letter:
                return False

        for letter, info in self.unplaced.items():
            hits = [i for i, ch in enumerate(word) if ch == letter]
            if any(i in info.positions for i in hits):
                # The letter sits somewhere it is known not to be.
                return False
            count = len(hits)
            if info.complete:
                if count != info.number:
                    return False
            elif count < info.number:
                return False
        return True
def oracle(word: str, answer: str) -> Knowledge:
    """Score the guess `word` against `answer` and return what was learned.

    Positions are visited left to right.  A letter equal to the answer's
    letter at the same index is recorded as placed.  Any other letter is
    looked up in a pool of the answer's letters: a hit records the guess
    position as misplaced and uses up that pool entry, a miss flags the
    letter as having no copies left.

    NOTE: exact matches do not remove their letter from the pool, so a
    repeated letter can be reported as misplaced against an answer letter
    that is also matched exactly at another position.
    """
    pool: list[str | None] = list(answer)
    exact: set[tuple[int, str]] = set()
    # letter -> [a lookup for it missed, guess positions where it was misplaced]
    feedback: dict[str, list[bool | set[int]]] = defaultdict(
        lambda: [False, set()]
    )

    for position, (guessed, actual) in enumerate(zip(word, answer)):
        if guessed == actual:
            exact.add((position, guessed))
            continue
        pool_index = next(
            (j for j, letter in enumerate(pool) if letter == guessed), None
        )
        if pool_index is None:
            feedback[guessed][0] = True
        else:
            feedback[guessed][1].add(position)
            # Each answer letter can back at most one misplaced report.
            pool[pool_index] = None

    # Letters with no misplaced position were only ever missed.
    absent: set[str] = {
        letter for letter, (_, spots) in feedback.items() if not spots
    }
    misplaced: dict[str, Unplaced] = {
        letter: Unplaced(missed, len(spots), spots)
        for letter, (missed, spots) in feedback.items()
        if spots
    }
    return Knowledge(not_contained=absent, placed=exact, unplaced=misplaced)
def l(
    w1: str,
    w2: str,
    k: "Knowledge",
    W: list[str],
    guesses: list[str],
    answers: list[str],
) -> Optional[int]:
    """Count the guesses needed to reach the answer `w2`, starting with `w1`.

    `w1` is the guess being played now, `k` the knowledge held before it,
    `W` the full word list, `guesses` the guesses made so far and `answers`
    the remaining candidate answers.  `b` passes a `guesses` list (via
    `average_length`) that already includes `w1`.

    Returns 1 when `w1` is the answer, otherwise 1 plus the count for the
    best follow-up guess.  Returns None when `b` finds no follow-up guess
    or the follow-up cannot reach the answer.
    """
    if w1 == w2:
        return 1

    updated_knowledge = oracle(w1, w2) + k
    best_guess = b(updated_knowledge, W, guesses, answers)
    if best_guess is None:
        return None

    # Record the follow-up guess before recursing: `b` relies on `guesses`
    # both to skip words already played and to enforce its turn limit.
    next_length = l(
        best_guess,
        w2,
        updated_knowledge,
        W,
        guesses + [best_guess],
        answers,
    )
    if next_length is None:
        return None
    return 1 + next_length
def average_length(
    w: str,
    answers: list[str],
    k: "Knowledge",
    W: list[str],
    guesses: list[str],
) -> Optional[int]:
    """Sum the guesses needed when `w` is played against every answer.

    Despite the name, the result is the total over `answers`, not a mean.
    Returns None as soon as one answer cannot be reached, and 0 when
    `answers` is empty.
    """
    total = 0
    for candidate in answers:
        steps = l(w, candidate, k, W, guesses, answers)
        if steps is None:
            return None
        total += steps
    return total
def b(
    k: "Knowledge",
    W: list[str],
    guesses: Optional[list[str]] = None,
    answers: Optional[list[str]] = None,
) -> Optional[str]:
    """Pick the best next guess from the word list `W`.

    `k` is the current knowledge, `guesses` the words already played
    (defaults to none) and `answers` the candidate answers to narrow down
    (defaults to `W`).

    Returns the single remaining candidate if only one fits `k`; otherwise
    the unplayed word in `W` with the smallest `average_length` score,
    the earliest word winning ties.  Returns None once six guesses have
    been played or when no word yields a score.  If no candidate fits `k`,
    every word scores 0 and the first unplayed word in `W` is returned.
    """
    if guesses is None:
        guesses = []
    elif len(guesses) >= 6:
        # All six turns are used up, so there is no next guess to offer.
        return None

    candidates = W if answers is None else answers
    answers = [w for w in candidates if k.contains(w)]
    if len(answers) == 1:
        return answers[0]

    best: Optional[tuple[str, int]] = None
    for w in W:
        if w in guesses:
            continue
        al = average_length(w, answers, k, W, guesses + [w])
        if al is not None and (best is None or al < best[1]):
            best = (w, al)
    return None if best is None else best[0]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment