Skip to content

Instantly share code, notes, and snippets.

@saolsen
Created July 11, 2023 19:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save saolsen/57cf8b90a17812cf02ae46cee9c31b97 to your computer and use it in GitHub Desktop.
Save saolsen/57cf8b90a17812cf02ae46cee9c31b97 to your computer and use it in GitHub Desktop.
Connect 4 Monte Carlo tree search with Rayon
use rayon::prelude::*;
use thiserror::Error;
/// Number of rows on the board, i.e. the number of cells in one column.
/// Cells are addressed as `board[col * ROWS + row]`.
const ROWS: usize = 6;
/// Number of columns on the board; valid column indices are `0..COLS`.
const COLS: usize = 7;
/// A single move: drop the next player's piece into one column.
///
/// This is a small plain-data value, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Connect4Action {
    /// Zero-based index of the column to play; playable values are `0..COLS`.
    pub column: usize,
}
/// Complete state of a game: the board contents and whose turn it is.
///
/// Two states compare equal when both the board and the player to move match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Connect4State {
    /// Board cells stored column by column: the cell at (`col`, `row`) lives at
    /// `board[col * ROWS + row]`, and pieces fill a column from row 0 upwards.
    /// `None` is an empty cell, `Some(p)` a cell holding player `p`'s piece.
    pub board: Vec<Option<usize>>,
    /// Index of the player who moves next; it alternates between 0 and 1.
    pub next_player: usize,
}
impl Default for Connect4State {
fn default() -> Self {
Self {
board: vec![None; ROWS * COLS],
next_player: 0,
}
}
}
/// Final outcome of a finished game.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Connect4Result {
    /// The player with this index completed four in a row.
    Winner(usize),
    /// The board filled up without anyone completing four in a row.
    Tie,
}

/// Status of a game as reported by `check_state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Connect4Check {
    /// No winner yet and at least one column can still be played.
    InProgress,
    /// The game has ended with the given result.
    Over(Connect4Result),
}
/// Determines whether the game is still running, has been won, or is tied.
///
/// Cells are read as `board[col * ROWS + row]`. Lines of four are scanned in a fixed
/// order: vertical, horizontal, rising diagonal, then falling diagonal. The owner of the
/// first complete line found is reported as the winner. When no line exists the game is
/// in progress as long as the top cell of any column is empty, and a tie otherwise.
///
/// All scan bounds are derived from `ROWS` and `COLS`, so the function stays correct if
/// the board dimensions change.
///
/// # Panics
///
/// Panics if a cell that has to be inspected lies beyond the end of `state.board`
/// (i.e. the board holds fewer than `ROWS * COLS` cells).
pub fn check_state(state: &Connect4State) -> Connect4Check {
    use Connect4Check::*;
    use Connect4Result::*;

    /// Number of aligned pieces needed to win.
    const RUN: usize = 4;
    /// Offsets of a line's cells relative to its first cell.
    const STEPS: [usize; RUN] = [0, 1, 2, 3];

    // Given the board indices of a line's cells, returns the player holding all of them.
    let line_owner = |cells: [usize; RUN]| -> Option<usize> {
        let owner = state.board[cells[0]]?;
        if cells[1..].iter().all(|&cell| state.board[cell] == Some(owner)) {
            Some(owner)
        } else {
            None
        }
    };

    // Vertical lines: one column, consecutive rows.
    for col in 0..COLS {
        for row in 0..=ROWS - RUN {
            if let Some(winner) = line_owner(STEPS.map(|k| col * ROWS + row + k)) {
                return Over(Winner(winner));
            }
        }
    }

    // Horizontal lines: one row, consecutive columns.
    for row in 0..ROWS {
        for col in 0..=COLS - RUN {
            if let Some(winner) = line_owner(STEPS.map(|k| (col + k) * ROWS + row)) {
                return Over(Winner(winner));
            }
        }
    }

    // Rising diagonals: column and row both increase.
    for col in 0..=COLS - RUN {
        for row in 0..=ROWS - RUN {
            if let Some(winner) = line_owner(STEPS.map(|k| (col + k) * ROWS + row + k)) {
                return Over(Winner(winner));
            }
        }
    }

    // Falling diagonals: column increases while the row decreases, so the line has to
    // start at least `RUN - 1` rows up.
    for col in 0..=COLS - RUN {
        for row in RUN - 1..ROWS {
            if let Some(winner) = line_owner(STEPS.map(|k| (col + k) * ROWS + row - k)) {
                return Over(Winner(winner));
            }
        }
    }

    // No winner: the game continues while the top cell of any column is still empty.
    let any_column_open = (0..COLS).any(|col| state.board[col * ROWS + ROWS - 1].is_none());
    if any_column_open {
        InProgress
    } else {
        Over(Tie)
    }
}
/// Reasons `apply_action` can reject an action.
#[derive(Error, Debug)]
pub enum ActionError {
    /// The requested column index is `COLS` or larger; carries the offending index.
    // NOTE(review): the message hard-codes the upper bound 6 (`COLS - 1`); keep it in
    // sync if `COLS` ever changes.
    #[error("Column must be between 0 and 6. Got `{0}`.")]
    UnknownColumn(usize),
    /// Every cell of the requested column is already occupied; carries the column index.
    #[error("Column `{0}` is full.")]
    FullColumn(usize),
}
/// Reports whether `action` can be played in `state`.
///
/// An action is playable when its column index is below `COLS` and the top cell of that
/// column is still empty.
pub fn check_action(state: &Connect4State, action: &Connect4Action) -> bool {
    let column = action.column;
    if column < COLS {
        // The column has room exactly when its topmost cell is unoccupied.
        let top_cell = column * ROWS + ROWS - 1;
        state.board[top_cell].is_none()
    } else {
        false
    }
}
pub fn apply_action(
state: &mut Connect4State,
action: &Connect4Action,
) -> Result<Connect4Check, ActionError> {
use ActionError::*;
if action.column >= COLS {
return Err(UnknownColumn(action.column));
}
for row in 0..ROWS {
let cell = &mut state.board[action.column * ROWS + row];
if cell.is_none() {
*cell = Some(state.next_player);
break;
}
if row == ROWS - 1 {
return Err(FullColumn(action.column));
}
}
state.next_player = 1 - state.next_player;
Ok(check_state(state))
}
/// Plays a game from `state` until it is over and returns the result.
///
/// `blue_agent` chooses the moves for player 0 and `red_agent` for player 1. `state` is
/// advanced in place and holds the final position when the function returns.
///
/// If `state` is already finished when the function is called, its result is returned
/// immediately and no agent is consulted. Without this check an agent would be asked
/// for a move on a finished (possibly completely full) board.
///
/// # Errors
///
/// Returns the `ActionError` from `apply_action` if an agent picks an action that
/// cannot be applied.
fn play(
    state: &mut Connect4State,
    blue_agent: fn(&Connect4State) -> Connect4Action,
    red_agent: fn(&Connect4State) -> Connect4Action,
) -> Result<Connect4Result, ActionError> {
    // Nothing to play if the starting position is already decided.
    if let Connect4Check::Over(result) = check_state(state) {
        return Ok(result);
    }
    loop {
        let agent = if state.next_player == 0 {
            blue_agent
        } else {
            red_agent
        };
        let action = agent(state);
        // `apply_action` already reports the status after the move; no second check needed.
        if let Connect4Check::Over(result) = apply_action(state, &action)? {
            return Ok(result);
        }
    }
}
/// Agent that plays a uniformly random column among those that are currently playable.
///
/// The playable columns are enumerated first and a single random draw selects one of
/// them, so the function always terminates (retrying random columns until one is valid
/// would spin forever on a full board).
///
/// # Panics
///
/// Panics if no column is playable, i.e. the board is full.
fn rand_agent(state: &Connect4State) -> Connect4Action {
    use rand::Rng;

    // Fixed-size buffer: this runs once per simulated move, so avoid heap allocation.
    let mut open_columns = [0usize; COLS];
    let mut open_count = 0;
    for column in 0..COLS {
        if check_action(state, &Connect4Action { column }) {
            open_columns[open_count] = column;
            open_count += 1;
        }
    }
    assert!(
        open_count > 0,
        "rand_agent: no playable column (the board is full)"
    );

    let pick = rand::thread_rng().gen_range(0..open_count);
    Connect4Action {
        column: open_columns[pick],
    }
}
fn mcts_agent(state: &Connect4State) -> Connect4Action {
// For each possible action, take the action and then simulate multiple random games from that
// state.
// Keep track of the number of wins for each action.
// Pick the action with the highest win rate.
let player = state.next_player;
(0..COLS)
.into_par_iter()
.map(|col| Connect4Action { column: col })
.filter(|action| check_action(state, action))
.map(|action| {
let mut next_state = state.clone();
apply_action(&mut next_state, &action).unwrap();
// Simulate 100 games from this action.
let score = (0..100)
.into_par_iter()
.map(move |_| {
let mut state = next_state.clone();
match play(&mut state, rand_agent, rand_agent).unwrap() {
Connect4Result::Winner(winner) => {
if winner == player {
1
} else {
-1
}
}
Connect4Result::Tie => 0,
}
})
.sum::<i32>() as f32
/ 100.;
(action, score)
})
// Pick the action with the highest score.
.max_by(|(_, score1), (_, score2)| score1.partial_cmp(score2).unwrap())
.map(|(action, _)| action)
.unwrap()
}
// Timing notes:
//   0.14s for 10, release build
//   0.24s for 10, release build with rayon (slower)
//   2.26s for 100, rayon on just the 0..100 loop
//   1.82s for 100, rayon on everything
//   1.8s  for 100 when the games in `main` run in parallel too; the output is then
//         printed out of order

/// Plays 100 games of `rand_agent` (player 0) against `mcts_agent` (player 1) in
/// parallel and prints the result of each game as it finishes.
fn main() {
    (0..100).into_par_iter().for_each(|game| {
        let mut board = Connect4State::default();
        let outcome = play(&mut board, rand_agent, mcts_agent).unwrap();
        println!("Game {}: {:?}", game, outcome);
    });
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment