Skip to content

Instantly share code, notes, and snippets.

@felixdae
Created July 4, 2023 15:48
Show Gist options
  • Save felixdae/deb1088891bff4bc7961d248d717b82d to your computer and use it in GitHub Desktop.
Save felixdae/deb1088891bff4bc7961d248d717b82d to your computer and use it in GitHub Desktop.
solve tic tac toe
import copy
from typing import List, Dict
def encode(board: List[List[int]]) -> int:
    """Pack the first three rows of ``board`` into one base-3 integer.

    Cells are read row-major (top-left first). The first cell is the most
    significant digit, so a 3x3 board maps into ``range(3 ** 9)``.

    Raises:
        AssertionError: if a cell is not 0, 1 or 2.
    """
    cells = [*board[0], *board[1], *board[2]]
    code = 0
    for cell in cells:
        assert cell in [0, 1, 2]
        code = 3 * code + cell
    return code
def decode(state: int) -> List[List[int]]:
    """Unpack a base-3 integer produced by ``encode`` back into a 3x3 board.

    The least significant digit becomes the bottom-right cell. The result
    is a fresh list of three row lists.

    Raises:
        AssertionError: if ``state`` needs more than nine base-3 digits.
    """
    digits = [0] * 9
    pos = 8
    remaining = state
    while remaining > 0:
        assert pos >= 0
        remaining, digits[pos] = divmod(remaining, 3)
        pos -= 1
    return [digits[0:3], digits[3:6], digits[6:9]]
def is_draw(board: List[List[int]]):
    """Return True when the board is full (five 1s and four 2s), else False.

    This function does not look for a completed line; ``next_action`` calls
    ``is_win_or_lose`` before it.

    Raises:
        AssertionError: if the board is not full and the mark counts are not
            a valid turn order (count of 1s minus count of 2s must be 0 or 1).
    """
    ones, twos = count_cell(board)
    full = ones == 5 and twos == 4
    if not full:
        assert ones - twos in [0, 1]
    return full
def is_win_or_lose(board: List[List[int]]):
    """Return True if any row, column or diagonal holds three equal non-zero cells.

    Either player's line counts, so the result only says the game is decided,
    not who decided it.
    """
    def lines():
        # Lazily yield rows and columns interleaved, then both diagonals, so
        # the search stops at the first completed line.
        for k in range(3):
            yield board[k][0], board[k][1], board[k][2]
            yield board[0][k], board[1][k], board[2][k]
        yield board[0][0], board[1][1], board[2][2]
        yield board[0][2], board[1][1], board[2][0]

    return any(a != 0 and a == b == c for a, b, c in lines())
def count_cell(board: List[List[int]]):
    """Count how many cells of a 3x3 board hold 1 and how many hold 2.

    Returns:
        ``(count1, count2)``: the number of 1-cells and the number of 2-cells.

    Raises:
        AssertionError: if any cell is not 0, 1 or 2.
    """
    count1 = 0
    count2 = 0
    for i in range(3):
        for j in range(3):
            cell = board[i][j]
            if cell == 1:
                count1 += 1
            elif cell == 2:
                count2 += 1
            else:
                # Only empty (0) cells may remain. Their count is never needed,
                # so this branch only validates the cell.
                assert cell == 0
    return count1, count2
def is_legal(board: List[List[int]]) -> bool:
    """Return True if ``board`` is a position the solver should keep.

    A kept position has a valid turn order (count of 1s minus count of 2s
    is 0 or 1) and no completed line. Boards that already contain a line
    are terminal and excluded.

    Raises:
        AssertionError: if the board is not 3 rows by 3 columns, or a cell
            is not 0, 1 or 2.
    """
    assert len(board) == 3
    assert len(board[0]) == 3
    ones, twos = count_cell(board)
    if ones - twos not in (0, 1):
        return False
    return not is_win_or_lose(board)
def place(board: List[List[int]], r: int, c: int) -> (int, bool):
    """Simulate the next player's move at ``(r, c)`` without touching ``board``.

    The mover is player 1 when both players have the same number of marks,
    otherwise player 2.

    Returns:
        ``(player, finished)``: who moved, and whether that move completes
        a line.

    Raises:
        AssertionError: if the board already has a completed line, the mark
            counts are not a valid turn order, or ``(r, c)`` is not empty.
    """
    # Per-row copies are sufficient because cells are immutable ints. This
    # function runs for every empty cell of every state on every
    # value-iteration sweep, so avoiding copy.deepcopy's overhead matters.
    tmp = [list(row) for row in board]
    assert not is_win_or_lose(tmp)
    count1, count2 = count_cell(tmp)
    assert count1 - count2 in [0, 1]
    player = 1 if count1 == count2 else 2
    assert tmp[r][c] == 0
    tmp[r][c] = player
    return player, is_win_or_lose(tmp)
def legal_states():
    """Return a dict mapping every encoded board that passes ``is_legal`` to 0.

    Every one of the 3 ** 9 possible codes is tried in ascending order. The
    0 is the starting value for ``value_iterate``.
    """
    return {code: 0 for code in range(3 ** 9) if is_legal(decode(code))}
def calc_state_value(s: int, states: Dict[int, int]):
    """Perform one negamax backup for the encoded position ``s``.

    Values are from the point of view of the player to move: +1 is a win,
    -1 is a loss and 0 is a draw.

    Each empty cell is scored in one of two ways. If moving there completes
    a line, the score is -1. Otherwise the score is the opponent's current
    value of the resulting position, ``states[next_s]``. The value of ``s``
    is the negation of the lowest score. A full board has no moves and keeps
    its current ``states[s]``.
    """
    board = decode(s)
    outcomes = []
    for r, row in enumerate(board):
        for c, cell in enumerate(row):
            if cell != 0:
                continue
            mover, won = place(board, r, c)
            if won:
                outcomes.append(-1)
                continue
            # Temporarily apply the move to encode the successor, then undo it.
            row[c] = mover
            successor = encode(board)
            row[c] = 0
            assert states[successor] in [0, 1, -1]
            outcomes.append(states[successor])
    if not outcomes:
        # Literally no update: the board is full.
        return states[s]
    best = min(outcomes)
    assert best in [0, 1, -1]
    return -best
def value_iterate(states: Dict[int, int]):
    """Repeat synchronous sweeps of ``calc_state_value`` until nothing changes.

    Each sweep computes every new value from the previous sweep's table.
    The ``states`` dict passed in is never modified.

    Returns:
        ``(num, states)``: the number of sweeps run (including the final sweep
        that changed nothing) and the converged value table.
    """
    sweeps = 0
    current = states
    changed = True
    while changed:
        fresh = {s: calc_state_value(s, current) for s in current}
        changed = any(fresh[s] != current[s] for s in fresh)
        current = fresh
        sweeps += 1
    return sweeps, current
def next_action(states: Dict[int, int], board: List[List[int]]) -> (bool, List[List[int]]):
    """Pick the best move for the player to move, using the value table ``states``.

    Any move that completes a line wins at once and is preferred. Otherwise
    the chosen move leaves the opponent with the lowest value in ``states``.
    Ties go to the first such cell in row-major order.

    Returns:
        ``(True, None)`` if the game is already over (a completed line or a
        full board). Otherwise ``(False, new_board)``, where ``new_board`` is
        a fresh board with the chosen move applied. ``board`` is never
        modified, so successive results do not alias each other or the
        caller's board.
    """
    if is_win_or_lose(board):
        return True, None
    if is_draw(board):
        return True, None
    # Work on a private copy so the caller's board is never touched.
    work = [list(row) for row in board]
    best_v = 100  # sentinel above any real value
    best_action = None
    for i in range(3):
        for j in range(3):
            if work[i][j] != 0:
                continue
            player, finished = place(work, i, j)
            if finished:
                # Favor shorter action sequence: an immediate win scores -2,
                # below any table value (>= -1), so it always beats a slower win.
                if -2 < best_v:
                    best_v = -2
                    best_action = (i, j, player)
            else:
                work[i][j] = player
                next_s = encode(work)
                work[i][j] = 0
                assert states[next_s] in [0, 1, -1]
                if states[next_s] < best_v:
                    best_v = states[next_s]
                    best_action = (i, j, player)
    i, j, player = best_action
    work[i][j] = player
    return False, work
def action_sequence(states: Dict[int, int], start: List[List[int]]):
    """Yield each board reached by repeatedly applying ``next_action`` from ``start``.

    ``start`` itself is not yielded. The generator stops as soon as
    ``next_action`` reports that the game is over.
    """
    board = start
    terminated, board_after = next_action(states, board)
    while not terminated:
        board = board_after
        yield board
        terminated, board_after = next_action(states, board)
def to_str(start: List[List[int]]):
    """Render a board as text: one line per row, cells separated by single spaces."""
    return "\n".join(" ".join(map(str, row)) for row in start)
def main():
    """Solve the game by value iteration, then print play from a sample position.

    The printed value is from the point of view of the player to move.
    """
    _, values = value_iterate(legal_states())
    # Player 1 holds the centre and player 2 the top-middle. Counts are
    # equal, so player 1 is to move.
    board = [
        [0, 2, 0],
        [0, 1, 0],
        [0, 0, 0],
    ]
    print(f"value is {values[encode(board)]} of board:")
    print(to_str(board))
    for step in action_sequence(values, board):
        print("-" * 7)
        print(to_str(step))


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment