Created
December 24, 2019 19:34
-
-
Save dodger487/c6fcf10912ab55b20fefffecd81d281c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Chris Riederer | |
# 2019-06-27 | |
"""Toy experiment with reinforcement learning.""" | |
from collections import Counter | |
import random | |
import time | |
import numpy as np | |
import sklearn.preprocessing | |
import xgboost | |
"""Represent the game.""" | |
class Values():
  """Symbolic constants for board cells and game outcomes."""
  X = 1      # player X's mark
  O = -1     # player O's mark
  blank = 0  # empty cell
  tie = 0    # drawn-game outcome (shares the blank value)
def create_new_board():
  """Return a fresh board: a flat, immutable 9-tuple of blank cells."""
  # A tuple keeps the board hashable; moves build new boards via make_move.
  return tuple(np.zeros(9, dtype=int))
"""Printing.""" | |
def print_board_quick(board):
  """Print the board as three raw row slices, one per line.

  Bug fix: the original sliced ``board[row:row+3]``, which printed
  overlapping windows of the first cells (indices 0:3, 1:4, 2:5); each
  row actually starts at index ``3*row``.  Also ported the Python 2
  print statement to the print() function.
  """
  for row in range(3):
    print(board[3 * row:3 * row + 3])
def print_board(board):
  """Pretty-print the 9-cell board as a 3x3 grid with separator lines.

  Cells show the raw numeric values (1 for X, -1 for O, 0 for blank).
  Ported the Python 2 print statements to print() calls; output is a
  top border followed by (row, border) for each of the three rows.
  """
  def print_separater():
    print("+----+----+----+")

  def print_row(row):
    # " 2d" pads to width 2 so 1 and -1 stay column-aligned.
    print("| {: 2d} | {: 2d} | {: 2d} |".format(*row))

  print_separater()
  for row in range(3):
    print_row(board[3 * row:3 * row + 3])
    print_separater()
"""Check if game is over.""" | |
def check_all_same(item1, item2, item3):
  """Return True iff all three cells carry the same player's mark.

  A triple of blanks does not count: only X or O lines win.
  """
  if item1 != item2 or item2 != item3:
    return False
  return item1 in (Values.X, Values.O)
def check_rows(board):
  """Return True if any horizontal line of the board is won."""
  return any(
      check_all_same(*board[start:start + 3]) for start in (0, 3, 6))
def check_columns(board):
  """Return True if any vertical line of the board is won."""
  # Column cells are every third index starting at the column number.
  return any(check_all_same(*board[col::3]) for col in range(3))
def check_diags(board):
  """Return True if either diagonal of the board is won."""
  main_diag = (board[0], board[4], board[8])
  anti_diag = (board[2], board[4], board[6])
  return check_all_same(*main_diag) or check_all_same(*anti_diag)
def check_win(board):
  """Return True if the board has three-in-a-row anywhere."""
  for line_checker in (check_rows, check_columns, check_diags):
    if line_checker(board):
      return True
  return False
"""Query and manipulate the board.""" | |
def available_moves(board):
  """List the indices of every blank cell on the board, ascending."""
  blank = Values.blank
  return [pos for pos, cell in enumerate(board) if cell == blank]
def make_move(board, player, move):
  """Return a new board tuple with player's mark placed at index move.

  The input board is never mutated; boards stay immutable tuples.
  """
  return board[:move] + (player,) + board[move + 1:]
"""Strategies.""" | |
def random_move(board):
  """Strategy: choose a uniformly random legal move."""
  legal = available_moves(board)
  return random.choice(legal)
def human(board):
  """Strategy: prompt a human on stdin until a legal move index is typed.

  Fixes: the original used Python 2 ``input()``, which *evaluates* the
  typed text (arbitrary code execution), and a bare ``except`` that
  swallowed every error including KeyboardInterrupt.  Here the text is
  read verbatim and only ``int`` conversion failures are handled.
  """
  moves = available_moves(board)
  print("Available moves:", moves, ":", end=" ")
  while True:
    text = input()
    print(text)  # echo what was read, as the original did
    try:
      move = int(text)
    except ValueError:
      print("I didn't like your input.")
      continue
    if move in moves:
      return move
    print("I didn't like your input.")
def model_factory(model):
  """Wrap a fitted classifier into a greedy board strategy.

  The returned callable scores every legal (state, action) pair with
  ``model.predict_proba`` and plays the action whose predicted
  win probability (class-1 column) is highest.
  """
  def strat(board):
    moves = available_moves(board)
    features = np.array(
        [represent_state_action(board, move) for move in moves])
    win_probs = model.predict_proba(features)[:, 1]
    return moves[int(np.argmax(win_probs))]
  return strat
"""Manage data representation.""" | |
ohe = sklearn.preprocessing.OneHotEncoder(n_values=9) | |
ohe.fit([[0]]) # This doesn't do anything but prevents in error in sklearn version I'm using. | |
def represent_state_action(board, move):
  """Encode a (board, move) pair as a flat length-18 float vector.

  The first 9 entries are the raw board cells; the last 9 are a one-hot
  encoding of the move index.

  Fix: the original routed the move through the module-level sklearn
  ``OneHotEncoder`` with a bare scalar, which breaks on modern sklearn
  (``transform`` requires a 2-D array and ``n_values`` was removed).
  Plain NumPy produces the identical vector with no version dependence.
  """
  move_one_hot = np.zeros(9)
  move_one_hot[move] = 1.0
  return np.append(np.array(board), move_one_hot)
def update_state_actions(state_actions, this_action):
  """Append one state-action feature row to the (n, 18) history array."""
  # vstack promotes the 1-D row to shape (1, 18) and concatenates.
  return np.vstack((state_actions, this_action))
def get_labels(state_actions, winner):
  """Label each recorded move: 1 if the mover won, 0 if they lost.

  Moves alternate X, O, X, ... so even rows belong to X and odd rows
  to O; the loser's rows are zeroed.  On a tie every row keeps label 1,
  crediting both players — NOTE(review): confirm the tie labelling is
  intentional.
  """
  labels = np.ones(len(state_actions))
  loser_rows = None
  if winner == Values.X:
    loser_rows = slice(1, None, 2)  # O's moves lost
  elif winner == Values.O:
    loser_rows = slice(0, None, 2)  # X's moves lost
  if loser_rows is not None:
    labels[loser_rows] = 0
  return labels
def get_player_x_data(state_actions, winner):
  """Return X's state-action rows (even indices) and win/loss labels.

  Labels are all 1 when winner is Values.X, else all 0.

  Bug fix: the original wrote ``winner == Values.X * np.ones(...)``,
  which by operator precedence multiplies first and compares second; it
  only produced correct answers by coincidence because Values.X == 1.
  The comparison is now performed explicitly before broadcasting.
  """
  data = state_actions[0::2]
  label = np.full(len(data), int(winner == Values.X))
  return data, label
def get_player_o_data(state_actions, winner):
  """Return O's state-action rows (odd indices) and win/loss labels.

  Labels are all 1 when winner is Values.O, else all 0.

  Bug fix: the original wrote ``winner == Values.O * np.ones(...)``,
  which multiplies first (yielding -ones) and compares second; the
  result was right only by coincidence because Values.O == -1.  The
  comparison is now performed explicitly before broadcasting.
  """
  data = state_actions[1::2]
  label = np.full(len(data), int(winner == Values.O))
  return data, label
def extract_board(state_action):
  """Recover the 9 board cells (as ints) from an 18-dim feature vector."""
  board_part = state_action[:9]
  return board_part.astype(int)
"""Run a game.""" | |
def play_game(strategyX, strategyO, display=False, display_wait=0.5, record_data=False):
  """Play one game of tic-tac-toe between two strategy callables.

  Args:
    strategyX: callable board -> move index; plays first as X.
    strategyO: callable board -> move index; plays second as O.
    display: if True, print the board (and pause) after each move.
    display_wait: seconds to sleep between displayed moves.
    record_data: if True, record each (state, action) pair as a row.

  Returns:
    (state_actions, winner): an (n_moves, 18) array (empty unless
    record_data is set) and Values.X, Values.O, or Values.tie.

  Fix: ported the Python 2 print statements to print() calls; game
  logic is otherwise unchanged.
  """
  board = create_new_board()
  player = Values.X
  strategy = strategyX
  state_actions = np.empty((0, 18), int)
  while True:
    if not available_moves(board):
      if display:
        print("Tie game")
      return (state_actions, Values.tie)
    # Decide on a move.
    move = strategy(board)
    # Record the move as a (state, action) pairing before applying it.
    if record_data:
      this_action = represent_state_action(board, move)
      state_actions = update_state_actions(state_actions, this_action)
    # Update the board.
    board = make_move(board, player, move)
    if display:
      print_board(board)
      time.sleep(display_wait)
    if check_win(board):
      if display:
        print("Player", player, "wins!")
      return (state_actions, player)
    # Hand the turn to the other player.
    if player == Values.X:
      player, strategy = Values.O, strategyO
    else:
      player, strategy = Values.X, strategyX
def compare_strategies(strategy1, strategy2, num_games=100):
  """Pit two strategies against each other in both seat orders.

  Plays num_games with strategy1 as X, then num_games with strategy2 as
  X, printing for each ordering the outcome counts and the mean outcome
  value (+1 favors the X seat, -1 the O seat, 0 is a tie).

  Fixes: ported Python 2 print statements to print() calls and
  collapsed the two copy-pasted measurement loops into one helper.
  """
  def report(label, strat_x, strat_o):
    # One seating: strat_x always plays X, strat_o always plays O.
    winners = [play_game(strat_x, strat_o)[1] for _ in range(num_games)]
    outcome_counts = Counter(winners)
    print(outcome_counts)
    print(label)
    print(sum(k * v for k, v in outcome_counts.items()) / float(num_games))

  report("strategy1 vs strategy2:", strategy1, strategy2)
  report("strategy2 vs strategy1:", strategy2, strategy1)
def generate_data(num_games, strategy1, strategy2, display=False):
  """Generate training data by playing num_games recorded games.

  Args:
    num_games: number of games to play.
    strategy1, strategy2: callables board -> move; strategy1 plays X.
    display: forwarded to play_game for verbose output.

  Returns:
    data: (total_moves, 18) float array of state-action features.
    labels: (total_moves,) array of per-move 1/0 labels (see get_labels).

  Improvement: per-game arrays are collected in Python lists and
  concatenated once at the end; the original called np.append inside
  the loop, copying the whole accumulated array every iteration
  (quadratic in total moves).
  """
  per_game_data = [np.empty((0, 18), int)]
  per_game_labels = [np.empty(0)]
  for _ in range(num_games):
    this_data, winner = play_game(
        strategy1, strategy2, record_data=True, display=display)
    per_game_data.append(this_data)
    per_game_labels.append(get_labels(this_data, winner))
  data = np.concatenate(per_game_data, axis=0)
  labels = np.concatenate(per_game_labels)
  return data, labels
# Script: bootstrap a model from random play, then iterate self-play rounds.
# Each round fits a fresh XGBoost classifier on all data gathered so far,
# compares it against random play / itself / the previous round, and
# appends new self-play data to the training set.
# (Ported Python 2 print statements to print() calls.)

print("Compare random:")
compare_strategies(random_move, random_move, 1000)

print("GENERATE: Train strategy on random players.")
data, labels = generate_data(1000, random_move, random_move)

ROUNDS = 5
models = []
strats = []
for i in range(ROUNDS):
  print("***** Round", i, "*****")
  # Fit a fresh classifier on everything gathered so far.
  model = xgboost.XGBClassifier(max_depth=9)
  model.fit(data, labels)
  models.append(model)
  new_strat = model_factory(model)
  strats.append(new_strat)

  print("COMPARE: trained strategy to random.")
  compare_strategies(new_strat, random_move, 500)
  print("COMPARE: trained strategy to itself.")
  compare_strategies(new_strat, new_strat, 500)
  if i > 0:
    print("COMPARE: trained strategy to last round.")
    compare_strategies(new_strat, strats[i - 1], 500)

  print("GENERATE: Play trained strategy against itself.")
  more_data, more_labels = generate_data(1000, new_strat, new_strat)
  # Only a handful of rounds, so appending here is cheap enough.
  data = np.append(data, more_data, axis=0)
  labels = np.append(labels, more_labels)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment