Last active
May 30, 2022 11:11
-
-
Save kirillbobyrev/a24a811fbd7cd816916c1e04d87efa7a to your computer and use it in GitHub Desktop.
Reinforcement Learning vs Tic Tac Toe: Temporal Difference (TD) Agent that beats Tic Tac Toe game through self-play
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' | |
Author: Kirill Bobyrev (https://github.com/kirillbobyrev) | |
This module implements "An Extended Example: Tic Tac Toe" from `Reinforcement | |
Learning: An Introduction`_ book by Richard S. Sutton and Andrew G. Barto | |
(January 1, 2018 complete draft) described in Section 1.5. The implemented | |
Reinforcement Learning algorithm is TD(0) and it is trained via self-play | |
between two agents. The update rule is slightly modified given the environment | |
specifics to comply with the one introduced in the Chapter 1, but as shown | |
later is equivalent to the one used in generic settings. | |
Example: | |
In order to run this script you would require a recent Python 3 interpreter | |
(versions 3.6 and newer) and a few PyPI packages (numpy, tqdm). To train a | |
TD(0) agent and launch an interactive session to play against the AI simply | |
run:: | |
$ python tic_tac_toe.py | |
If you would like to take the first turn against the AI run:: | |
$ python tic_tac_toe.py --take_first_turn | |
Learning the policy for the Reinforcement Learning action would take around | |
a minute by default (20000 episodes), use --episodes to alter the number | |
of training simulations:: | |
$ python tic_tac_toe.py --episodes 1000 | |
.. _Reinforcement Learning: An Introduction: | |
http://incompleteideas.net/book/the-book-2nd.html | |
''' | |
import argparse | |
import copy | |
import numpy | |
import typing | |
import tqdm | |
class Board(object):
    '''
    The classic 3 by 3 Tic Tac Toe board interface implementation, which is
    used as a part of the Reinforcement Learning environment. It provides the
    necessary routine methods for accessing the internal state and allows
    safely modifying it while maintaining a valid state.

    Cell coordinates are zero-based (row, column) indices. The top left
    cell's coordinates are (0, 0), the bottom right one's - (2, 2), i.e. the
    whole board looks like this:

        (0, 0) | (0, 1) | (0, 2)
        --------------------------
        (1, 0) | (1, 1) | (1, 2)
        --------------------------
        (2, 0) | (2, 1) | (2, 2)

    Cell values are 0 (empty), 1 (first player, 'X') and -1 (second player,
    'O').
    '''

    def __init__(self, cells: numpy.ndarray = None) -> None:
        '''
        Args:
            cells: Optional initial board contents; must be a 3x3 array with
                values in {-1, 0, 1}. An empty board is created by default.
        '''
        # Use the classic 3x3 size.
        self.size: int = 3
        # The first player ('X') always makes the first move.
        self.first_player_turn: bool = True
        if cells is not None:
            assert cells.shape == (self.size, self.size)
            self.cells: numpy.ndarray = cells
        else:
            self.cells = numpy.zeros((self.size, self.size), dtype=numpy.int8)

    def take_turn(self, cell: typing.Tuple[int, int]) -> None:
        '''
        Modifies the current board given the current player's decision and
        passes the turn to the other player.

        Expects the given cell to be empty, otherwise raises AssertionError.
        '''
        assert self.is_possible(cell)
        self.cells[cell] = 1 if self.first_player_turn else -1
        # Switch the current player after the turn.
        self.first_player_turn = not self.first_player_turn

    def is_possible(self, action: typing.Tuple[int, int]) -> bool:
        '''
        Checks whether an action is valid on this board.

        Args:
            action: Coordinates of the action to check for validity.

        Returns:
            bool: True if it is possible to put 'X' or 'O' into the given
                cell, False otherwise.
        '''
        return self.cells[action] == 0

    def possible_actions(self) -> numpy.ndarray:
        '''
        Outputs all possible actions from the current board state by choosing
        the cells not previously taken by either player.

        Returns:
            numpy.ndarray: An array of possible actions in row-major order.
        '''
        return numpy.array([(i, j)
                            for i in range(self.size)
                            for j in range(self.size)
                            if self.is_possible((i, j))])

    def is_over(self) -> typing.Tuple[bool, int]:
        '''
        Determines whether the game is over and hence no possible further
        action can be taken by either side.

        Returns:
            bool: True if the game is over, False otherwise.
            int: If the game is over, the identifier of the winner (1 or
                -1 for the first and the second player respectively), 0
                otherwise (including draws).
        '''
        # Collect all 8 potentially winning lines: 3 rows, 3 columns and
        # both diagonals. This replaces four near-identical scanning loops.
        lines = [self.cells[i, :] for i in range(self.size)]
        lines += [self.cells[:, i] for i in range(self.size)]
        lines.append(self.cells.diagonal())
        # fliplr + diagonal yields the anti-diagonal (left bottom to right
        # top in the original's terms).
        lines.append(numpy.fliplr(self.cells).diagonal())
        for line in lines:
            player_id = int(line[0])
            if player_id != 0 and (line == player_id).all():
                return True, player_id
        # If there is an empty cell, the game is not over yet.
        if (self.cells == 0).any():
            return False, 0
        # Otherwise all cells are taken and no player has won: it's a draw!
        return True, 0

    def hash(self) -> int:
        '''
        Bijectively maps the board state to its unique identifier by reading
        the 9 cells as digits of a base-3 number (-1 maps to digit 2 via the
        modulo operation).

        Returns:
            int: Unique identifier of the current Board state, within
                [0, 3 ** 9).
        '''
        result = 0
        # .flat iterates in row-major order, matching nested (i, j) loops.
        # int(...) keeps the accumulator a Python int instead of silently
        # promoting it to a numpy scalar.
        for value in self.cells.flat:
            result = result * 3 + int(value) % 3
        return result

    def __repr__(self) -> str:
        '''
        Returns the Tic Tac Toe board in a human-readable representation
        using 'X's, 'O's and whitespace for empty cells:

             X |   |
            -----------
               | O |
            -----------
               |   | X
        '''
        mapping = [' ', 'X', 'O']  # Index -1 wraps around to 'O'.
        result = ''
        for i in range(self.size):
            for j in range(self.size):
                result += ' {} '.format(mapping[self.cells[i][j]])
                result += '|' if j != self.size - 1 else '\n'
            if i != self.size - 1:
                result += ('-' * (2 + self.size * self.size)) + '\n'
        return result
def get_all_states(verbose: bool = False
                   ) -> typing.Tuple[typing.Set, typing.Set]:
    '''
    Devises all valid board states reachable via legal play and computes
    hashes for each of them. Also extracts terminal states useful for the
    update rule simplification.

    Args:
        verbose: If True, print the current BFS generation (turn number)
            while enumerating states. The previous version printed
            unconditionally, which polluted stdout for library users.

    Returns:
        set: A set of all possible boards' hashes.
        set: A set of hashes of all boards after a final turn, i.e. terminal
            boards.
    '''
    # Breadth-first search over board states: each generation holds the
    # boards reachable after the same number of turns.
    boards = [Board()]
    states = set()
    terminal_states = set()
    generation = 0
    while boards:
        if verbose:
            print(f'Epoch: {generation}')
        generation += 1
        next_generation = []
        for board in boards:
            board_hash = board.hash()
            # The same state can be reached through different move orders
            # (transpositions); skip states that were already visited.
            if board_hash in states:
                continue
            states.add(board_hash)
            over, _ = board.is_over()
            if over:
                # No moves follow a finished game.
                terminal_states.add(board_hash)
                continue
            for action in board.possible_actions():
                next_board = copy.deepcopy(board)
                next_board.take_turn(tuple(action))
                next_generation.append(next_board)
        boards = next_generation
    return states, terminal_states
class TicTacToe(object):
    '''
    TicTacToe is a Reinforcement Learning environment for this game, which
    reacts to players' moves, updates the internal state (Board) and samples
    reward.
    '''

    def __init__(self):
        self.board: Board = Board()

    def step(self,
             action: typing.Tuple[int, int]) -> typing.Tuple[int, Board, bool]:
        '''
        Updates the board given a valid action of the current player.

        Args:
            action: A valid action in a form of cell coordinates.

        Returns:
            int: Reward for the first player (0 while the game is running).
            Board: Resulting state.
            bool: True if the game is over, False otherwise.
        '''
        finished, _ = self.board.is_over()
        assert (self.board.is_possible(action))
        assert (not finished)
        self.board.take_turn(action)
        finished, winner = self.board.is_over()
        return winner, self.board, finished

    def __repr__(self):
        '''
        Returns the current board state as a human-readable string.
        '''
        return repr(self.board)

    def reset(self):
        '''
        Empties the board and starts a new game.
        '''
        # Recreating the board is equivalent to re-running __init__.
        self.board = Board()
class TDAgent(object):
    '''
    Tic Tac Toe-specific Temporal Difference [TD(0)] agent implementation.

    TODO(omtcvxyz): Allow saving and loading value estimates to omit training
    and skip to the interactive session for simplicity.
    '''

    def __init__(self,
                 environment: TicTacToe,
                 learning_rate: float = 0.1,
                 exploration_rate: float = 0.1) -> None:
        '''
        Args:
            environment: The TicTacToe environment shared with the opponent.
            learning_rate: TD(0) step-size parameter (alpha).
            exploration_rate: Probability of sampling a random (exploratory)
                action instead of the greedy one.
        '''
        self.environment: TicTacToe = environment
        self.learning_rate: float = learning_rate
        self.exploration_rate: float = exploration_rate
        # Tabular value function indexed by Board.hash(), which maps every
        # state into [0, 3 ** (size * size)). The previous allocation of
        # 3 ** (size ** 2 + 1) entries was 3 times larger than any hash can
        # address and wasted two thirds of the table.
        # TODO(omtcvxyz): Use get_all_states() to allocate memory only for
        # the reachable states instead of all combinations of 9 integers
        # within [0; 2] range. Given that the terminal states are known
        # beforehand, these should be also marked beforehand.
        self.value: numpy.ndarray = numpy.zeros(
            3 ** (self.environment.board.size ** 2))

    def reset_exploration_rate(self):
        '''
        Sets exploration rate to 0. This is useful whenever one would like to
        evaluate the agent's performance.
        '''
        self.exploration_rate = 0

    def consume_experience(self, initial_state: int, reward: int,
                           resulting_state: int, terminal: bool):
        '''
        Applies the TD(0) update for one observed transition.

        This code uses the formulation from the RL Book, Chapter 1. Although
        in general the TD update rule looks like this:

            V(S) = V(S) + alpha * [R_t + V(S') - V(S)]

        The environment only samples reward on episode completion and hence
        the value function of terminal states can be set to the sampled
        reward, which produces the following update rule (as proposed in
        Chapter 1 of the RL Book):

            V(S) = V(S) + alpha * [V(S') - V(S)]

        which is exactly the same as the generic one once augmented with the
        prior knowledge of the Tic Tac Toe environment.
        '''
        if terminal:
            # Pin the terminal state's value to the sampled reward so that
            # the regular update below propagates it backwards.
            self.value[resulting_state] = reward
        self.value[initial_state] += self.learning_rate * (
            self.value[resulting_state] - self.value[initial_state])

    def sample_action(self) -> typing.Tuple[typing.Tuple[int, int], bool]:
        '''
        Outputs an action leading to the state with the greatest value with
        probability 1 - self.exploration_rate. Samples a random valid action
        with probability self.exploration_rate.

        Returns:
            (int, int): Sampled action.
            bool: True if the sampled action is a result of a "greedy"
                transition, i.e. whether the sampled action is not
                exploratory.
        '''
        possible_actions = self.environment.board.possible_actions()
        # Exploration branch: a uniformly random valid action.
        if numpy.random.binomial(1, self.exploration_rate):
            random_index = numpy.random.randint(0, len(possible_actions))
            return tuple(possible_actions[random_index]), False
        # Exploitation branch: try every action on a copy of the board and
        # pick the one leading to the successor with the highest value.
        board_copies = [
            copy.deepcopy(self.environment.board) for _ in possible_actions
        ]
        for action, board in zip(possible_actions, board_copies):
            board.take_turn(tuple(action))
        hashes = [board.hash() for board in board_copies]
        best_state = numpy.argmax(self.value[hashes])
        return tuple(possible_actions[best_state]), True
def learn(episodes_count: int, learning_rate: float,
          verbose: bool) -> typing.Tuple[TDAgent, TDAgent]:
    '''
    Feeds experience generated during Tic Tac Toe games between two similar
    TD(0) agents to these agents while improving their policies.

    Args:
        episodes_count: Samples experience from episodes_count episodes. The
            more experience the agents have, the better learned policies are.
            Approximate rate of running simulations is ~400 games / second
            (16 Gb RAM, Intel Core i7 processor setup).
        learning_rate: Refers to the alpha TD(0) algorithm hyperparameter.
            The larger it is, the faster the learning process is, but it
            also becomes less sensitive. Optimally, learning_rate should
            slowly decrease to a very small value over time.
        verbose: Indicates whether progress is shown. tqdm is used for
            convenient terminal experience.

    Returns:
        (TDAgent, TDAgent): Temporal Difference agents trained to play as the
            first and the second player respectively.
    '''
    if verbose:
        print('Training Temporal Difference AI.')
    # Both agents share a single environment and take turns acting in it.
    environment: TicTacToe = TicTacToe()
    player_x: TDAgent = TDAgent(environment, learning_rate)
    player_o: TDAgent = TDAgent(environment, learning_rate)
    episode_iterator = range(episodes_count)
    if verbose:
        episode_iterator = tqdm.tqdm(episode_iterator)
    for _ in episode_iterator:
        x_to_move: bool = True
        while True:
            mover = player_x if x_to_move else player_o
            action, greedy = mover.sample_action()
            x_to_move = not x_to_move
            state_before = environment.board.hash()
            reward, _, over = environment.step(action)
            state_after = environment.board.hash()
            # Skip the update when the last transition was exploratory.
            if greedy:
                player_x.consume_experience(state_before, reward,
                                            state_after, over)
                # The second player consumes the inverted reward, because it
                # is sampled for the first player.
                player_o.consume_experience(state_before, -reward,
                                            state_after, over)
            if over:
                environment.reset()
                break
    return player_x, player_o
def launch_interactive_session(AI: TDAgent, take_first_turn: bool):
    '''
    Launches a continuous interactive session, in which a human player can
    challenge a Reinforcement Learning agent previously trained using
    self-play.

    Args:
        AI: The Reinforcement Learning agent, which faces the human player.
        take_first_turn: If True the human player will always take the first
            turn, the AI will always take the first turn otherwise.
    '''
    AI.reset_exploration_rate()
    environment: TicTacToe = AI.environment
    environment.reset()
    while True:
        print('Playing against AI')
        human_turn = take_first_turn
        print(environment.board)
        while True:
            if human_turn:
                print('Type coordinates (pair of 0-based space-separated '
                      'integers) of the cell you would like to take:')
                while True:
                    try:
                        x, y = map(int, input().split())
                        action: tuple = (x, y)
                        # Reject out-of-range coordinates explicitly: numpy
                        # accepts negative indices, so e.g. (-1, -1) would
                        # silently address the bottom right cell.
                        if not (0 <= x < environment.board.size
                                and 0 <= y < environment.board.size):
                            raise ValueError('coordinates out of range')
                        reward, _, over = environment.step(action)
                    # The previous bare `except:` also swallowed
                    # KeyboardInterrupt and SystemExit; catch only what
                    # invalid input can actually raise (ValueError from
                    # parsing/range check, AssertionError from step()).
                    except (ValueError, AssertionError):
                        print('Sorry, the input is invalid. Try again.')
                        continue
                    else:
                        break
                print()
            else:
                action, greedy = AI.sample_action()
                # Learn while playing against the human player.
                previous_state = environment.board.hash()
                reward, _, over = environment.step(action)
                current_state = environment.board.hash()
                # The reward is sampled for the first player, so invert it
                # when the AI plays second.
                if take_first_turn:
                    AI.consume_experience(previous_state, -reward,
                                          current_state, over)
                else:
                    AI.consume_experience(previous_state, reward,
                                          current_state, over)
            human_turn = not human_turn
            print(environment.board)
            if over:
                if (reward == 1
                        and take_first_turn) or (reward == -1
                                                 and not take_first_turn):
                    print('You won! Congratulations!')
                elif (reward == -1
                      and take_first_turn) or (reward == 1
                                               and not take_first_turn):
                    print('The AI won! Try again!')
                else:
                    print('It\'s a draw!')
                environment.reset()
                break
        print()
        answer = input('Would you like to play another game? (y/N) ')
        if answer.lower() != 'y' and answer.lower() != 'yes':
            break
def main():
    '''
    Parses command line arguments, trains two TD(0) agents via self-play and
    optionally launches an interactive session against the trained AI.
    '''
    parser: argparse.ArgumentParser = argparse.ArgumentParser(
        description='''This script implements Temporal Difference agent for
        the classic Tic Tac Toe environment and learns a policy by playing
        against itself. A human player can play against the trained agent
        upon the training completion. TD(0) parameters can be changed via
        command line arguments and options.''')
    parser.add_argument(
        '-v',
        '--verbose',
        help='increase output verbosity',
        action='store_true')
    parser.add_argument(
        '--learning_rate',
        default=0.1,
        type=float,
        help='step-size parameter (alpha); passed value '
        'should be within (0, 1] range [defaults to 0.1]')
    parser.add_argument(
        '--no-interactive',
        action='store_true',
        help='unless this option is passed, interactive session against AI '
        'is launched after learning the policy')
    parser.add_argument(
        '--seed',
        default=42,
        type=int,
        help='fix numpy random seed for reproducibility [defaults to 42]')
    parser.add_argument(
        '--episodes',
        default=20000,
        type=int,
        help='train temporal difference AI agent for EPISODES games [defaults '
        'to 20000]')
    parser.add_argument(
        '--take_first_turn',
        action='store_true',
        help='always take the first turn; unless passed, human player will '
        'always take the second turn')
    arguments: argparse.Namespace = parser.parse_args()
    # `assert` is stripped under `python -O`, so report invalid user input
    # through argparse instead.
    if not 0 < arguments.learning_rate <= 1:
        parser.error('--learning_rate should be within (0, 1] range')
    numpy.random.seed(arguments.seed)
    first_turn_AI, second_turn_AI = learn(
        arguments.episodes, arguments.learning_rate, arguments.verbose)
    # The AI plays the role opposite to the human player's.
    if arguments.take_first_turn:
        AI = second_turn_AI
    else:
        AI = first_turn_AI
    if not arguments.no_interactive:
        launch_interactive_session(AI, arguments.take_first_turn)


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' | |
Author: Kirill Bobyrev (https://github.com/kirillbobyrev) | |
Unit tests for Tic Tac Toe and the TD Agent implementation. | |
''' | |
import copy | |
import numpy | |
import tic_tac_toe | |
import unittest | |
class BoardTest(unittest.TestCase):
    '''Unit tests for the Board state container.'''

    def setUp(self):
        self.board = tic_tac_toe.Board()

    def test_construction(self):
        # A freshly constructed board is empty.
        for row in self.board.cells:
            for cell in row:
                self.assertEqual(cell, 0)

    def test_take_turn(self):
        # Players alternate automatically: 1, then -1, then 1 again.
        self.board.take_turn((0, 0))
        self.assertEqual(self.board.cells[0][0], 1)
        self.board.take_turn((1, 1))
        self.assertEqual(self.board.cells[1][1], -1)
        self.board.take_turn((2, 2))
        self.assertEqual(self.board.cells[2][2], 1)

    def test_is_possible(self):
        self.assertTrue(self.board.is_possible((0, 0)))
        self.board.take_turn((0, 0))
        self.assertFalse(self.board.is_possible((0, 0)))
        self.assertTrue(self.board.is_possible((1, 1)))
        self.board.take_turn((1, 1))
        self.assertFalse(self.board.is_possible((1, 1)))

    def test_possible_actions(self):
        cells_to_take = ((0, 0), (2, 0), (2, 2), (1, 2), (2, 1), (1, 1))
        taken_cells = set()
        for cell in cells_to_take:
            self.board.take_turn(cell)
            taken_cells.add(cell)
        expected = numpy.array([(i, j)
                                for i in range(3) for j in range(3)
                                if (i, j) not in taken_cells])
        # numpy.array_equal compares shapes and contents. The previous
        # assertEqual(a.all(), b.all()) only compared two scalars and would
        # pass for almost any pair of arrays.
        self.assertTrue(
            numpy.array_equal(self.board.possible_actions(), expected))

    def test_is_over(self):
        # Check horizontal combination.
        over, winner = self.board.is_over()
        self.assertFalse(over)
        self.board.take_turn((2, 0))
        self.board.take_turn((1, 0))
        self.board.take_turn((2, 1))
        over, winner = self.board.is_over()
        self.assertFalse(over)
        self.board.take_turn((1, 1))
        self.board.take_turn((2, 2))
        over, winner = self.board.is_over()
        self.assertTrue(over)
        self.assertEqual(winner, 1)
        # Reset the board and check vertical combination.
        self.setUp()
        over, winner = self.board.is_over()
        self.assertFalse(over)
        self.board.take_turn((0, 1))
        self.board.take_turn((1, 0))
        self.board.take_turn((2, 2))
        self.board.take_turn((0, 0))
        self.board.take_turn((0, 2))
        self.board.take_turn((2, 0))
        over, winner = self.board.is_over()
        self.assertTrue(over)
        self.assertEqual(winner, -1)
        # Reset the board and check diagonal combination (left top to right
        # bottom).
        self.setUp()
        over, winner = self.board.is_over()
        self.assertFalse(over)
        self.board.take_turn((0, 0))
        self.board.take_turn((1, 0))
        self.board.take_turn((1, 1))
        self.board.take_turn((0, 1))
        self.board.take_turn((2, 2))
        over, winner = self.board.is_over()
        self.assertTrue(over)
        self.assertEqual(winner, 1)
        # Reset the board and check diagonal combination (left bottom to
        # right top).
        self.setUp()
        over, winner = self.board.is_over()
        self.assertFalse(over)
        self.board.take_turn((2, 0))
        self.board.take_turn((1, 0))
        self.board.take_turn((1, 1))
        self.board.take_turn((0, 1))
        self.board.take_turn((0, 2))
        over, winner = self.board.is_over()
        self.assertTrue(over)
        self.assertEqual(winner, 1)

    def test_hash(self):
        # hash() reads the cells as base-3 digits, (2, 2) being the least
        # significant one; -1 maps to digit 2.
        self.assertEqual(self.board.hash(), 0)
        self.board.take_turn((2, 2))
        self.assertEqual(self.board.hash(), 1)
        self.board.take_turn((2, 1))
        self.assertEqual(self.board.hash(), 1 + 3 * 2)
class TicTacToe(unittest.TestCase):
    '''Unit tests for the TicTacToe environment.'''

    def setUp(self):
        self.tic_tac_toe = tic_tac_toe.TicTacToe()

    def test_step(self):
        # The first player fills the left column while the second one fills
        # the middle column; no reward until the winning move.
        for move in ((0, 0), (1, 0), (0, 1), (1, 1)):
            reward, _, over = self.tic_tac_toe.step(move)
            self.assertEqual(reward, 0)
            self.assertFalse(over)
        # The winning move yields a reward of 1 and finishes the game.
        reward, _, over = self.tic_tac_toe.step((0, 2))
        self.assertEqual(reward, 1)
        self.assertTrue(over)

    def test_reset(self):
        # The board starts empty, gets modified by a step and is emptied
        # again by reset().
        self.assertTrue((self.tic_tac_toe.board.cells == 0).all())
        self.tic_tac_toe.step((0, 0))
        self.assertEqual(self.tic_tac_toe.board.cells[0][0], 1)
        self.tic_tac_toe.reset()
        self.assertTrue((self.tic_tac_toe.board.cells == 0).all())
class TDAgent(unittest.TestCase):
    '''Unit tests for the TD(0) agent.'''

    def setUp(self):
        self.environment = tic_tac_toe.TicTacToe()
        self.agent = tic_tac_toe.TDAgent(self.environment)

    def test_consume_experience(self):
        # Checks the TD(0) update both mid-game and on a terminal move.
        self.environment.step((0, 0))
        self.environment.step((1, 0))
        state_before = self.environment.board.hash()
        value_before = self.agent.value[state_before]
        reward, _, over = self.environment.step((0, 1))
        state_after = self.environment.board.hash()
        value_after = self.agent.value[state_after]
        # Ensure the last turn is not the winning one.
        self.assertFalse(over)
        self.assertEqual(reward, 0)
        # Test the agent's value function update.
        self.agent.consume_experience(state_before, reward, state_after,
                                      over)
        self.assertEqual(
            self.agent.value[state_before],
            value_before + self.agent.learning_rate *
            (value_after - value_before))
        # Play the game to the end.
        self.environment.step((2, 0))
        state_before = self.environment.board.hash()
        value_before = self.agent.value[state_before]
        reward, _, over = self.environment.step((0, 2))
        state_after = self.environment.board.hash()
        # Ensure the last turn was indeed the winning one.
        self.assertTrue(over)
        self.assertEqual(reward, 1)
        # The terminal state's value is pinned to the reward and propagated.
        self.agent.consume_experience(state_before, reward, state_after,
                                      over)
        self.assertEqual(self.agent.value[state_after], reward)
        self.assertEqual(
            self.agent.value[state_before],
            value_before + self.agent.learning_rate *
            (reward - value_before))

    def test_sample_action(self):
        # A purely greedy agent must pick the successor of maximal value.
        self.agent.reset_exploration_rate()
        candidates = self.environment.board.possible_actions()
        successors = [
            copy.deepcopy(self.environment.board) for _ in candidates
        ]
        for candidate, successor in zip(candidates, successors):
            successor.take_turn(tuple(candidate))
        action, greedy = self.agent.sample_action()
        self.environment.board.take_turn(action)
        chosen_value = self.agent.value[self.environment.board.hash()]
        for successor in successors:
            self.assertLessEqual(self.agent.value[successor.hash()],
                                 chosen_value)
if __name__ == '__main__': | |
unittest.main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment