Last active
August 26, 2022 05:25
-
-
Save peter-lang/e091ecf0b86f9442ff0eefaafec196b1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import struct | |
import tqdm | |
# Board encoding:
#   empty cell: 0
#   player 1 (max): pieces of size 1, 2, 3 (two of each)
#   player 2 (min): pieces of size -1, -2, -3 (two of each)
#   a piece may only be placed on a cell holding a strictly smaller piece (or nothing)
#   board: 3x3 stored as a 9-long tuple (board[i][j] == tuple[i*3 + j])
# Rotations (clockwise; each grid shows which original index lands in each cell)
#   e          r1 (90)    r2 (180)   r3 (270)
#   0 1 2      6 3 0      8 7 6      2 5 8
#   3 4 5  =>  7 4 1  =>  5 4 3  =>  1 4 7
#   6 7 8      8 5 2      2 1 0      0 3 6
# Reflections
#   Tx         T-1=Tx(r1) Ty=Tx(r2)  T+1=Tx(r3)
#   2 1 0      0 3 6      6 7 8      8 5 2
#   5 4 3  =>  1 4 7  =>  3 4 5  =>  7 4 1
#   8 7 6      2 5 8      0 1 2      6 3 0
def symmetries(e):
    """Return the set of boards equivalent to *e* under the 8 symmetries of the square.

    *e* is a 9-long board tuple (row-major). The result always contains *e*
    itself plus the three rotations and four reflections. Duplicates collapse
    because a set is returned, so symmetric boards yield fewer than 8 entries.
    """
    # Index layouts: position k of the image takes e[layout[k]].
    layouts = (
        (6, 3, 0, 7, 4, 1, 8, 5, 2),  # r1: rotate 90
        (8, 7, 6, 5, 4, 3, 2, 1, 0),  # r2: rotate 180
        (2, 5, 8, 1, 4, 7, 0, 3, 6),  # r3: rotate 270
        (2, 1, 0, 5, 4, 3, 8, 7, 6),  # Tx: mirror left/right
        (0, 3, 6, 1, 4, 7, 2, 5, 8),  # T-1: main-diagonal transpose
        (6, 7, 8, 3, 4, 5, 0, 1, 2),  # Ty: mirror top/bottom
        (8, 5, 2, 7, 4, 1, 6, 3, 0),  # T+1: anti-diagonal transpose
    )
    return {e} | {tuple(e[i] for i in layout) for layout in layouts}
def score(table):
    """Return 1 if player 1 owns a full line, -1 if player 2 does, else None.

    A line is owned when all three of its cells hold pieces of the same sign
    (any size). Rows are checked first, then columns, then the two diagonals.
    """
    rows = ((0, 1, 2), (3, 4, 5), (6, 7, 8))
    columns = ((0, 3, 6), (1, 4, 7), (2, 5, 8))
    diagonals = ((0, 4, 8), (2, 4, 6))
    for line in rows + columns + diagonals:
        cells = [table[i] for i in line]
        if all(v > 0 for v in cells):
            return 1
        if all(v < 0 for v in cells):
            return -1
    return None
def canonical_table(table):
    """Return the representative of *table*'s symmetry class.

    The representative is the lexicographically smallest of its symmetric
    images, so all equivalent boards map to the same tuple.
    """
    return min(symmetries(table))
def moves(table, figures, is_max):
    """Yield each distinct placement available to the player to move.

    Args:
        table: 9-long board tuple.
        figures: remaining piece counts per size (index 0 -> size 1, ...).
        is_max: True for the positive-piece player, False for the negative one.

    Yields:
        (next_table, remaining_figures) pairs. next_table is canonicalized, and
        placements of the same piece size that are symmetric to an earlier one
        are skipped.
    """
    sign = 1 if is_max else -1
    for figure_idx, count in enumerate(figures):
        if count == 0:
            continue
        piece = sign * (figure_idx + 1)
        seen = set()  # canonical boards already produced for this piece size
        for cell in range(9):
            # A piece can only go on a cell holding a strictly smaller piece.
            if abs(table[cell]) >= abs(piece):
                continue
            placed = canonical_table(
                tuple(piece if i == cell else v for i, v in enumerate(table)))
            if placed in seen:
                continue
            seen.add(placed)
            yield placed, tuple(c - 1 if i == figure_idx else c
                                for i, c in enumerate(figures))
# Progress bar ticked once for every terminal position the search reaches.
progress = tqdm.tqdm()


def minimax(table, max_player, min_player, depth, resolved_states):
    """Solve a position by exhaustive, memoized minimax.

    Args:
        table: 9-long board tuple; callers pass canonical boards.
        max_player: remaining piece counts of the positive (maximizing) player.
        min_player: remaining piece counts of the negative (minimizing) player.
        depth: plies played so far; even depth means the maximizer moves.
        resolved_states: memo dict, filled in for every non-terminal position
            as (table, max_player, min_player) -> (value, best_next_table).

    Returns:
        (value, table): the game value (1, -1 or 0) and the input table itself.
    """
    # depth is not part of the key: turns strictly alternate, so the remaining
    # piece counts already determine who is to move.
    key = (table, max_player, min_player)
    cached = resolved_states.get(key)
    if cached is not None:
        return cached[0], table

    maximizing = depth % 2 == 0
    leaf_value = score(table)
    children = None
    if leaf_value is None:
        to_move = max_player if maximizing else min_player
        children = list(moves(table, to_move, maximizing))
        if not children:
            leaf_value = 0  # no legal placement left: treated as a draw
    if leaf_value is not None:
        progress.update(1)
        progress.set_postfix({"size": len(resolved_states)})
        return leaf_value, table

    if maximizing:
        outcomes = (minimax(child, rest, min_player, depth + 1, resolved_states)
                    for child, rest in children)
        best = max(outcomes, key=lambda outcome: outcome[0])
    else:
        outcomes = (minimax(child, max_player, rest, depth + 1, resolved_states)
                    for child, rest in children)
        best = min(outcomes, key=lambda outcome: outcome[0])
    resolved_states[key] = best
    return best[0], table
if __name__ == '__main__':
    solved = {}
    print(minimax((0,) * 9, (2, 2, 2), (2, 2, 2), 0, solved))
    # One record per solved position: board (9), max counts (3), min counts (3),
    # value (1), best next board (9) -- 25 big-endian 32-bit ints.
    record = struct.Struct('>25i')
    with open('solution.dat', 'wb') as out:
        for (board, max_left, min_left), (value, best_next) in solved.items():
            out.write(record.pack(*board, *max_left, *min_left, value, *best_next))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import struct | |
import re | |
def update_tuple(orig, idx, val):
    """Return a tuple copy of *orig* with the item at *idx* replaced by *val*."""
    head = tuple(orig[:idx])
    tail = tuple(orig[idx + 1:])
    return head + (val,) + tail
def update_tuple_add_value(orig, idx, val):
    """Return a tuple copy of *orig* with *val* added to the item at *idx*."""
    bumped = orig[idx] + val
    return tuple(orig[:idx]) + (bumped,) + tuple(orig[idx + 1:])
def symmetries(e):
    """Return the 8 symmetric images of board *e*, in a fixed order.

    The order is identity, r1, r2, r3, Tx, T-1, Ty, T+1. It is a list (not a
    set) so callers iterate the images deterministically; duplicates are kept.
    """
    # Index layouts: position k of the image takes e[layout[k]].
    layouts = (
        (6, 3, 0, 7, 4, 1, 8, 5, 2),  # r1
        (8, 7, 6, 5, 4, 3, 2, 1, 0),  # r2
        (2, 5, 8, 1, 4, 7, 0, 3, 6),  # r3
        (2, 1, 0, 5, 4, 3, 8, 7, 6),  # Tx
        (0, 3, 6, 1, 4, 7, 2, 5, 8),  # T-1
        (6, 7, 8, 3, 4, 5, 0, 1, 2),  # Ty
        (8, 5, 2, 7, 4, 1, 6, 3, 0),  # T+1
    )
    return [e] + [tuple(e[i] for i in layout) for layout in layouts]
def score(state):
    """Evaluate a game state ``(table, human_pieces, computer_pieces)``.

    Returns 1 if the positive pieces own a full line, -1 if the negative
    pieces do, 0 if neither side has pieces left, otherwise None (game on).
    Line wins are checked before the out-of-pieces draw.
    """
    board = state[0]
    lines = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
        (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
        (0, 4, 8), (2, 4, 6),             # diagonals
    )
    for a, b, c in lines:
        cells = (board[a], board[b], board[c])
        if all(v > 0 for v in cells):
            return 1
        if all(v < 0 for v in cells):
            return -1
    if sum(state[1]) == 0 and sum(state[2]) == 0:
        return 0
    return None
def canonical_table(table):
    """Return the smallest symmetric image of *table* (its canonical form)."""
    return min(symmetries(table))
def diff(table, orig):
    """Return the single index where the two 9-cell boards differ.

    Returns None when the boards are identical or differ in more than one cell.
    Scanning stops as soon as a second difference is found.
    """
    found = None
    for i in range(9):
        if table[i] == orig[i]:
            continue
        if found is not None:
            return None
        found = i
    return found
def restore_move(table, orig):
    """Translate a canonical successor *table* back into a move on *orig*.

    Tries the symmetric images of *table* in order and returns ``(index, value)``
    for the first image that differs from *orig* in exactly one cell.
    Returns None if no image matches.
    """
    for candidate in symmetries(table):
        cell = diff(candidate, orig)
        if cell is None:
            continue
        return cell, candidate[cell]
    return None
def print_table(table):
    """Print the board with column headers 1-3 and row labels a-c.

    Empty cells print as ' 0', positive pieces with a leading '+', and
    negative pieces with their '-' sign.
    """
    def render(value):
        if value == 0:
            return ' 0'
        return '+' + str(value) if value > 0 else str(value)

    print(" | 1 | 2 | 3")
    for row_idx, row_name in enumerate('abc'):
        start = row_idx * 3
        cells = table[start:start + 3]
        print(row_name + ' | ' + ' | '.join(render(v) for v in cells))
def take_user_guess(inp, state):
    """Parse a human move such as ``"a2 3"`` and apply it to *state*.

    The move format is ``<row a-c><column 1-3> <piece size 1-3>``; whitespace
    between the parts is optional. The human places positive pieces, and
    ``state[1]`` holds the human's remaining count per piece size.

    Returns:
        The new ``(table, human_pieces, computer_pieces)`` state, or None if the
        input is malformed, the target cell holds a piece that is not strictly
        smaller, or the human has no piece of that size left.
    """
    # Raw string: "\s" in a plain literal is an invalid escape sequence
    # (DeprecationWarning; SyntaxWarning since Python 3.12).
    m = re.match(r'^\s*([a-c])\s*([1-3])\s*([1-3])\s*$', inp)
    if not m:
        return None
    idx = (ord(m.group(1)) - ord('a')) * 3 + int(m.group(2)) - 1
    val = int(m.group(3))
    # Only a strictly larger piece may cover an occupied cell -- the same rule
    # the solver uses when generating moves (abs(figure) > abs(table[idx])).
    # Accepting equal sizes produced positions the solver never explored, and
    # the computer's lookup then failed.
    if val <= abs(state[0][idx]):
        return None
    if state[1][val - 1] <= 0:
        return None
    new_table = update_tuple(state[0], idx, val)
    new_player = update_tuple_add_value(state[1], val - 1, -1)
    return new_table, new_player, state[2]
def computer_move(comp_dict, state):
    """Play the solver's chosen reply for the computer (negative pieces).

    Looks up the canonical form of the current position in *comp_dict*, maps
    the stored canonical successor back onto the real board orientation,
    applies that move, and decrements the computer's count for the piece size.
    Prints the solver's predicted score and returns the new state.
    Raises KeyError if the position is not in *comp_dict*.
    """
    board = state[0]
    entry = comp_dict[(canonical_table(board), state[1], state[2])]
    cell, piece = restore_move(entry[1], board)
    next_board = update_tuple(board, cell, piece)
    computer_left = update_tuple_add_value(state[2], abs(piece) - 1, -1)
    print(f"predicted score {entry[0]}")
    return next_board, state[1], computer_left
def load_comp_move_dict():
    """Load the solver output from ``solution.dat`` in the working directory.

    Each record is 25 big-endian 32-bit ints: board (9), human piece counts (3),
    computer piece counts (3), predicted score (1), best next board (9).
    Returns a dict mapping ``(board, human_counts, computer_counts)`` to
    ``(score, next_board)``. Raises struct.error if the file size is not a
    whole number of records.
    """
    record = struct.Struct('>25i')
    with open('solution.dat', 'rb') as fp:
        payload = fp.read()
    moves_by_state = {}
    for fields in record.iter_unpack(payload):
        key = (fields[0:9], fields[9:12], fields[12:15])
        moves_by_state[key] = (fields[15], fields[16:25])
    return moves_by_state
if __name__ == '__main__':
    comp_move_dict = load_comp_move_dict()
    initial = ((0,) * 9, (2, 2, 2), (2, 2, 2))
    game = initial

    def _announce(result):
        # score() yields 1 (human line), -1 (computer line) or 0 (draw).
        verdict = 'human wins' if result > 0 else 'computer wins' if result < 0 else 'draw'
        print(f"game over: {verdict}")

    while True:
        print("human", game[1], "computer", game[2])
        print_table(game[0])
        user_inp = input('move (e.g.: a2 3):')
        if user_inp == 'reset':
            game = initial
            continue
        candidate = take_user_guess(user_inp, game)
        if candidate is None:
            print("illegal move")
            continue
        game = candidate
        result = score(game)
        if result is not None:
            _announce(result)
            break
        game = computer_move(comp_move_dict, game)
        result = score(game)
        if result is not None:
            print_table(game[0])
            _announce(result)
            break
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment