Some of the bit twiddling in my Solidity implementation of an on-chain chess engine, fiveoutofnine.sol, can be quite esoteric, so I've recreated some of the logic below. Note that the code below is only logically equivalent: it gives the same results, but it does things differently. Functions have similar names, so it should be easy to match them up with the functions in `Chess.sol` and `Engine.sol`. The main simplifications in these files are using an array for the board (as opposed to 64 bit-packed `uint4`s) and a simple class for moves (as opposed to 2 bit-packed `uint6`s). To simplify things further, it's written in a language anyone can read easily: Python.
class Move:
    """A move of the piece on board index ``from_`` to board index ``to``.

    ``from`` is a reserved word in Python, so the origin square is stored as
    ``from_`` (the PEP 8 trailing-underscore convention for keyword clashes).
    """

    __slots__ = ("from_", "to")

    def __init__(self, from_, to):
        self.from_ = from_
        self.to = to

    def __eq__(self, other):
        if not isinstance(other, Move):
            return NotImplemented
        return (self.from_, self.to) == (other.from_, other.to)

    def __hash__(self):
        return hash((self.from_, self.to))

    def __repr__(self):
        return f"Move({self.from_}, {self.to})"


class Chess:
    """Move generation and legality checks on a 64-entry board list.

    Board layout as used below: ``board`` is a list of 64 ints indexed as
    ``8 * row + col``. Only rows 1-6 and columns 1-6 are playable; the
    one-square border around them never is. A square holds 0 (empty) or a
    piece whose low 3 bits are its type (1 pawn, 2 bishop, 3 rook, 4 knight,
    5 queen, 6 king) and whose bit 3 is its color. ``board[-1]`` holds the
    color of the side to move. That side always pushes pawns toward higher
    indices; ``apply_move`` rotates the board so the opponent then sees it
    the same way.
    """

    # Playable squares in scan order: row 6 down to row 1, column 6 down to
    # column 1 (54, 53, ..., 49, 46, ..., 9).
    _SCAN_ORDER = tuple(
        8 * row + col for row in range(6, 0, -1) for col in range(6, 0, -1)
    )
    # Set form of the playable squares; membership also rejects indices
    # outside 0..63, so lookups never raise or wrap around.
    _PLAYABLE = frozenset(_SCAN_ORDER)

    # Index offsets for the non-sliding pieces; each is tried both ways.
    _KNIGHT_OFFSETS = (6, 10, 15, 17)
    _KING_OFFSETS = (1, 7, 8, 9)
    # Step vectors for sliding pieces.
    _ORTHOGONAL_STEPS = (-1, 1, 8, -8)
    _DIAGONAL_STEPS = (-7, 7, 9, -9)

    @staticmethod
    def _in_bounds(index):
        """Return whether ``index`` is a playable square (any int accepted)."""
        return index in Chess._PLAYABLE

    @staticmethod
    def _is_empty(board, index):
        """Return whether ``index`` is a playable, unoccupied square."""
        return Chess._in_bounds(index) and board[index] == 0

    @staticmethod
    def apply_move(board, move):
        """Return a new board with ``move`` played, rotated for the opponent.

        The caller's ``board`` is left untouched: the engine applies many
        candidate moves to the same position.
        """
        board = list(board)
        piece = board[move.from_]
        board[move.from_] = 0
        board[move.to] = piece
        return Chess.rotate(board)

    @staticmethod
    def rotate(board):
        """Return a reversed copy of the board (square i moves to 63 - i).

        NOTE(review): reversing also swaps board[0] and board[-1]. Because
        board[-1] is read as the side to move, the turn flips only if
        board[0] holds the other color; presumably the starting position
        sets it up that way. Confirm.
        """
        return board[::-1]

    @staticmethod
    def generate_moves(board):
        """Return the moves available to the side to move, in scan order.

        Moves follow piece-movement rules only; whether a move exposes the
        mover's own king is not checked here.
        """
        moves = []
        for from_index in Chess._SCAN_ORDER:
            piece = board[from_index]
            if piece == 0 or piece >> 3 != board[-1]:
                continue
            piece &= 7
            if piece == 1:
                targets = Chess._pawn_targets(board, from_index)
            elif piece in (4, 6):
                offsets = (
                    Chess._KNIGHT_OFFSETS if piece == 4 else Chess._KING_OFFSETS
                )
                targets = Chess._step_targets(board, from_index, offsets)
            else:
                targets = Chess._slide_targets(board, from_index, piece)
            moves.extend(Move(from_index, to) for to in targets)
        return moves

    @staticmethod
    def _pawn_targets(board, index):
        """Return destination squares for a pawn on ``index``."""
        targets = []
        # Pushes need an empty, playable square. The double push is only
        # from row 2 and cannot jump over an occupied square.
        if Chess._is_empty(board, index + 8):
            targets.append(index + 8)
            if index // 8 == 2 and Chess._is_empty(board, index + 16):
                targets.append(index + 16)
        for offset in (7, 9):
            if Chess.is_capture(board, index + offset):
                targets.append(index + offset)
        return targets

    @staticmethod
    def _step_targets(board, index, offsets):
        """Return destinations for a knight/king: each offset, forward then back."""
        targets = []
        for offset in offsets:
            if Chess.is_valid(board, index + offset):
                targets.append(index + offset)
            if Chess.is_valid(board, index - offset):
                targets.append(index - offset)
        return targets

    @staticmethod
    def _slide_targets(board, index, piece):
        """Return destinations for a sliding piece (bishop, rook, queen)."""
        steps = ()
        if piece != 2:  # every slider except the bishop moves orthogonally
            steps += Chess._ORTHOGONAL_STEPS
        if piece != 3:  # every slider except the rook moves diagonally
            steps += Chess._DIAGONAL_STEPS
        targets = []
        for step in steps:
            square = index + step
            while Chess.is_valid(board, square):
                targets.append(square)
                if Chess.is_capture(board, square):
                    break  # a slider stops on the piece it captures
                square += step
        return targets

    @staticmethod
    def search_ray(board, from_index, to_index, direction_vector):
        """Return whether a slider can travel from ``from_index`` to ``to_index``.

        The two squares must differ by a non-zero multiple of
        ``direction_vector``. Every square strictly between them must be
        playable and empty, and the destination must be playable and either
        empty or an opponent piece. The rule is the same in both directions.
        """
        delta = to_index - from_index
        if delta == 0 or delta % direction_vector != 0:
            return False
        step = direction_vector if delta > 0 else -direction_vector
        square = from_index + step
        while square != to_index:
            if not Chess._is_empty(board, square):
                return False
            square += step
        return Chess.is_valid(board, to_index)

    @staticmethod
    def is_legal_move(board, move):
        """Return whether ``move`` is legal for the side to move.

        Besides the piece-movement rules (kept consistent with
        ``generate_moves``), the move is rejected when ``Engine.nega_max``
        scores the resulting position, searched one ply, below -1_260.
        """
        from_index, to_index = move.from_, move.to
        if not (Chess._in_bounds(from_index) and Chess._in_bounds(to_index)):
            return False
        piece = board[from_index]
        if piece == 0 or piece >> 3 != board[-1]:
            return False
        piece &= 7
        index_change = abs(to_index - from_index)

        if piece == 1:
            if to_index <= from_index:
                return False
            if index_change in (7, 9):
                if not Chess.is_capture(board, to_index):
                    return False
            elif index_change == 8:
                if not Chess._is_empty(board, to_index):
                    return False
            elif index_change == 16:
                if (
                    from_index // 8 != 2
                    or not Chess._is_empty(board, to_index - 8)
                    or not Chess._is_empty(board, to_index)
                ):
                    return False
            else:
                return False
        elif piece in (4, 6):
            offsets = Chess._KNIGHT_OFFSETS if piece == 4 else Chess._KING_OFFSETS
            if index_change not in offsets or not Chess.is_valid(board, to_index):
                return False
        else:
            ray_found = False
            if piece != 2:
                ray_found = Chess.search_ray(
                    board, from_index, to_index, 1
                ) or Chess.search_ray(board, from_index, to_index, 8)
            if piece != 3:
                ray_found = (
                    ray_found
                    or Chess.search_ray(board, from_index, to_index, 7)
                    or Chess.search_ray(board, from_index, to_index, 9)
                )
            if not ray_found:
                return False

        # The opponent's greedy reply makes nega_max return -4_000 when it
        # lands on a king; any score that low rejects the move.
        return Engine.nega_max(Chess.apply_move(board, move), 1) >= -1_260

    @staticmethod
    def is_capture(board, index):
        """Return whether ``index`` is a playable square holding an opponent piece."""
        if not Chess._in_bounds(index):
            return False
        piece = board[index]
        return piece != 0 and piece >> 3 != board[-1]

    @staticmethod
    def is_valid(board, to_index):
        """Return whether a piece may land on ``to_index``.

        The square must be playable and either empty or hold an opponent piece.
        """
        if not Chess._in_bounds(to_index):
            return False
        piece = board[to_index]
        return piece == 0 or piece >> 3 != board[-1]


class Engine:
    """Greedy search and piece-square-table (PST) evaluation."""

    # PSTs: 36 entries each, one per playable square (see _pst_index).
    # NOTE(review): a few rows are not left-right symmetric (knight row 0,
    # queen rows 3-4, king row 2). Confirm against the Solidity tables.
    _PAWN_PST = (
        20, 20, 20, 20, 20, 20,
        30, 30, 30, 30, 30, 30,
        20, 22, 24, 24, 22, 20,
        21, 20, 26, 26, 20, 21,
        21, 30, 16, 16, 30, 21,
        20, 20, 20, 20, 20, 20,
    )
    _BISHOP_PST = (
        62, 64, 64, 64, 64, 62,
        64, 66, 66, 66, 66, 64,
        64, 67, 68, 68, 67, 64,
        64, 68, 68, 68, 68, 64,
        64, 67, 66, 66, 67, 64,
        62, 64, 64, 64, 64, 62,
    )
    _ROOK_PST = (
        100, 100, 100, 100, 100, 100,
        101, 102, 102, 102, 102, 101,
        99, 100, 100, 100, 100, 99,
        99, 100, 100, 100, 100, 99,
        99, 100, 100, 100, 100, 99,
        100, 100, 101, 101, 100, 100,
    )
    _KNIGHT_PST = (
        54, 56, 54, 54, 56, 58,
        56, 60, 64, 64, 60, 56,
        58, 64, 68, 68, 64, 58,
        58, 65, 68, 68, 65, 58,
        56, 60, 65, 65, 60, 56,
        54, 56, 58, 58, 56, 54,
    )
    _QUEEN_PST = (
        176, 178, 179, 179, 178, 176,
        178, 180, 180, 180, 180, 178,
        179, 180, 181, 181, 180, 179,
        179, 181, 181, 181, 180, 179,
        178, 180, 181, 180, 180, 178,
        176, 178, 179, 179, 178, 176,
    )
    _KING_PST = (
        3994, 3992, 3990, 3990, 3992, 3994,
        3994, 3992, 3990, 3990, 3992, 3994,
        3996, 3994, 3992, 3992, 3994, 3995,
        3998, 3996, 3996, 3996, 3996, 3998,
        4001, 4001, 4000, 4000, 4001, 4001,
        4004, 4006, 4002, 4002, 4006, 4004,
    )
    _PSTS = {
        1: _PAWN_PST,
        2: _BISHOP_PST,
        3: _ROOK_PST,
        4: _KNIGHT_PST,
        5: _QUEEN_PST,
    }

    @staticmethod
    def search_move(board, depth):
        """Pick a move for the side to move.

        Each candidate scores ``evaluate_move`` plus ``nega_max`` of the
        resulting position at ``depth - 1``; the first highest score wins.

        Returns:
            ``(0, False)`` if there are no moves or the best score is below
            -1_260. Otherwise ``(best_move, best_score > 1_260)``.
        """
        moves = Chess.generate_moves(board)
        if not moves:
            return (0, False)
        best_score = -4_196
        best_move = 0
        for move in moves:
            score = Engine.evaluate_move(board, move) + Engine.nega_max(
                Chess.apply_move(board, move), depth - 1
            )
            if score > best_score:
                best_score, best_move = score, move
        if best_score < -1_260:
            return (0, False)
        return (best_move, best_score > 1_260)

    @staticmethod
    def nega_max(board, depth):
        """Score the line where each side plays its best-evaluated move.

        Only the single best move by ``evaluate_move`` is followed at each
        ply, for ``depth`` plies. Returns 0 at depth 0 or when there are no
        moves. Returns -4_000 when the chosen move lands on a king. Each
        ply's score is added as-is when ``board[-1] == 0``, negated otherwise.
        """
        if depth == 0:
            return 0
        moves = Chess.generate_moves(board)
        if not moves:
            return 0
        best_score, best_move = -4_196, moves[0]
        for move in moves:
            score = Engine.evaluate_move(board, move)
            if score > best_score:
                best_score, best_move = score, move
        if board[best_move.to] & 7 == 6:
            return -4_000
        rest = Engine.nega_max(Chess.apply_move(board, best_move), depth - 1)
        return (best_score if board[-1] == 0 else -best_score) + rest

    @staticmethod
    def _pst_index(index):
        """Map a playable board index (8 * row + col, 1 <= row, col <= 6) to 0..35.

        NOTE(review): assumes PST entry 0 is board row 1, column 1. Confirm
        against the Solidity PST bit layout.
        """
        return 6 * (index // 8 - 1) + (index % 8 - 1)

    @staticmethod
    def evaluate_move(board, move):
        """Return the PST gain of ``move``.

        The gain is the captured piece's PST value on the target square (0 if
        the target is empty) plus the moving piece's PST change from origin
        to target. Only piece types (low 3 bits) select tables; the color bit
        is ignored.
        """
        piece_at_from = board[move.from_] & 7
        piece_at_to = board[move.to] & 7
        from_square = Engine._pst_index(move.from_)
        to_square = Engine._pst_index(move.to)
        capture_value = (
            Engine.get_pst(piece_at_to)[to_square] if piece_at_to != 0 else 0
        )
        pst = Engine.get_pst(piece_at_from)
        return capture_value + pst[to_square] - pst[from_square]

    @staticmethod
    def get_pst(type):
        """Return the 36-entry PST (a shared tuple) for a piece type.

        Types 1-5 map to pawn, bishop, rook, knight and queen. Any other
        value, including 6, gets the king table.
        """
        return Engine._PSTS.get(type, Engine._KING_PST)