Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
CUDA Othello Solver
// CUDA Othello Solver
// License: GPL 3
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
// Each board is searched by a "SIMD group" of simdWidth consecutive threads:
// lane k of the group evaluates flip direction k (see flip()), so a block of
// threadsPerBlock threads works on nodesPerBlock boards at a time.
constexpr int threadsPerBlock = 128;
constexpr int simdWidth = 4;
constexpr int nodesPerBlock = threadsPerBlock/simdWidth;
// Each SIMD group's search stack has MAX_DEPTH + 1 frames and every move pushes
// one frame, so a board can be solved only if it has at most MAX_DEPTH empty
// squares.
constexpr int MAX_DEPTH = 10;
using ull = unsigned long long;

// flip() indexes its four direction masks with the lane number, and the search
// ORs the lanes' partial results together with xor-1 / xor-2 shuffles, which
// covers exactly four lanes.
static_assert(simdWidth == 4, "one lane per flip direction: simdWidth must be 4");
static_assert(threadsPerBlock % simdWidth == 0, "a block must hold whole SIMD groups");
static_assert(sizeof(ull) == 8, "bitboards need 64-bit words");
// A position stored as two 64-bit bitboards (bit i is square i, eight bits per
// row). `player` holds the discs of the side to move, `opponent` those of the
// other side; making a move swaps the two roles.
struct Board {
  ull player;
  ull opponent;

  __host__ __device__ Board() = default;
  __host__ __device__ Board(ull mine, ull theirs) : player(mine), opponent(theirs) {}

  // Copying and moving are the defaulted member-wise operations.
  __host__ __device__ Board(const Board &) = default;
  __host__ __device__ Board(Board &&) = default;
  __host__ __device__ Board &operator=(const Board &) = default;
  __host__ __device__ Board &operator=(Board &&) = default;
};
// One frame of a SIMD group's negamax search stack.
struct Node {
  Board bd;       // position; bd.player is the side to move at this frame
  ull puttable;   // empty squares not yet tried as a move at this frame
  char alpha;     // lower bound of the search window, from bd.player's view
  char beta;      // upper bound of the search window, from bd.player's view
  bool pass;      // stays true until a legal move has been found here
  bool passed;    // set once the side to move had to pass and bd was swapped

  // Folds a finished child's value (given from the child's side to move) into
  // this frame: negate it into our point of view and raise alpha if larger.
  __device__ void update(char value) {
    const int negated = -value;
    if (negated > alpha) {
      alpha = static_cast<char>(negated);
    }
  }
};
// Per-block shared state. Each SIMD group (simdWidth consecutive threads) owns
// one row of nodes_stack: an explicit search stack of MAX_DEPTH + 1 frames,
// frame 0 being the root of the board the group is currently solving.
__shared__ Node nodes_stack[nodesPerBlock][MAX_DEPTH+1];
// One node counter per thread. Only lane 0 of a SIMD group increments its slot
// (in alpha_beta) and reports it (in get_next_node); all lanes reset theirs.
__shared__ int count[threadsPerBlock];
// Final score of a finished game from bd.player's point of view: the disc
// difference with every empty square credited to the side that has more
// discs, and 0 on an equal count.
__device__ int score(const Board &bd) {
  const int mine = __popcll(bd.player);
  const int theirs = __popcll(bd.opponent);
  if (mine > theirs) {
    return 64 - 2 * theirs;  // mine + empties - theirs
  }
  if (mine < theirs) {
    return 2 * mine - 64;    // -(theirs + empties - mine)
  }
  return 0;
}
// Ray masks used by flip() for the squares on a line with a LOWER bit index
// than the move square pos: mask1[d] >> (63 - pos). After the shift, the
// highest set bit is the square nearest pos; bits shifted below 0 are dropped.
// For directions 1-3 the shifted ray can wrap onto another row; flip()
// compensates by ignoring opponent discs on columns 0 and 7.
__constant__ ull mask1[4] = {
0x0080808080808080ULL,  // d = 0: pos-8, pos-16, ..., pos-56 (vertical)
0x7f00000000000000ULL,  // d = 1: pos-1 .. pos-7 (horizontal)
0x0102040810204000ULL,  // d = 2: pos-7, pos-14, ..., pos-49 (diagonal, stride 7)
0x0040201008040201ULL   // d = 3: pos-9, pos-18, ..., pos-63 (diagonal, stride 9)
};
// Ray masks used by flip() for the squares on a line with a HIGHER bit index
// than the move square pos: mask2[d] << pos. After the shift, the lowest set
// bit is the square nearest pos; bits shifted past bit 63 are dropped.
__constant__ ull mask2[4] = {
0x0101010101010100ULL,  // d = 0: pos+8, pos+16, ..., pos+56 (vertical)
0x00000000000000feULL,  // d = 1: pos+1 .. pos+7 (horizontal)
0x0002040810204080ULL,  // d = 2: pos+7, pos+14, ..., pos+49 (diagonal, stride 7)
0x8040201008040200ULL   // d = 3: pos+9, pos+18, ..., pos+63 (diagonal, stride 9)
};
// Returns the opponent discs flipped along one line through `pos` when
// bd.player places a disc on square `pos` (0..63, expected to be empty).
// `index` selects the line (0 vertical, 1 horizontal, 2 and 3 the diagonals
// with strides 7 and 9); both rays of that line are handled. A result of 0
// means the move flips nothing along this line.
__device__ ull flip(const Board &bd, int pos, int index) {
  ull om = bd.opponent;
  // Off the vertical, ignore opponent discs on columns 0 and 7 so that a run
  // of opponent discs cannot continue across a board edge into another row.
  if (index) om &= 0x7E7E7E7E7E7E7E7EULL;

  // Ray towards lower bit indices: the nearest square that is not an opponent
  // disc is the highest set bit of (~om & mask). __clzll(0) is 64, and shifting
  // a 64-bit value by 64 is undefined, so an all-opponent (or empty) ray is
  // handled explicitly: it has no outflanking disc.
  ull mask = mask1[index] >> (63 - pos);
  const ull stoppers = ~om & mask;
  ull outflank = stoppers
      ? ((0x8000000000000000ULL >> __clzll(stoppers)) & bd.player)
      : 0ULL;
  // Every ray square strictly above the outflanking disc (and below pos) flips.
  ull flipped = (-outflank * 2) & mask;

  // Ray towards higher bit indices: adding 1 carries through the contiguous
  // opponent run just above pos; the bit it lands on, if it is a player disc,
  // outflanks the run.
  mask = mask2[index] << pos;
  outflank = mask & ((om | ~mask) + 1) & bd.player;
  flipped |= (outflank - (outflank != 0)) & mask;
  return flipped;
}
// All opponent discs flipped when bd.player plays on `pos`: the union of
// flip() over the four line directions. Not called anywhere in this file; the
// search instead spreads the four directions across a SIMD group's lanes.
__device__ ull flip_all(const Board &bd, int pos) {
  ull all_flipped = 0;
  for (int direction = 0; direction < 4; ++direction) {
    all_flipped |= flip(bd, pos, direction);
  }
  return all_flipped;
}
// A SIMD group's view of the batch of boards passed to search_noordering:
// the input/output arrays plus the index of the board it is solving now.
// Initialised positionally, so the member order matters.
struct Arrays {
const Board *bd_ary;  // input boards
int *res_ary;  // output: result for each board, from its side to move
int *nodes_count;  // output: node counter value for each board
size_t index;  // board currently being solved by this SIMD group
const size_t size;  // number of boards in the three arrays
};
// Called by every lane of a SIMD group once the root of its current board is
// resolved with value `res` (from the root's side to move). Lane 0 stores the
// result and the node counter of the finished board. The group then advances
// to its next board, skipping as many boards as there are SIMD groups in the
// grid. Returns true when no boards remain. Otherwise every lane's counter is
// reset to 1, lane 0 has loaded the new root into frame 0 of the group's stack,
// and false is returned.
__device__ bool get_next_node(Arrays &arys, const int node_index, const int simd_index, const int res) {
  const bool is_leader = (simd_index == 0);
  if (is_leader) {
    const size_t finished = arys.index;
    arys.res_ary[finished] = res;
    arys.nodes_count[finished] = count[threadIdx.x];
  }

  const size_t groups_in_grid = (blockDim.x * gridDim.x) / simdWidth;
  arys.index += groups_in_grid;
  if (arys.index >= arys.size) {
    return true;
  }

  count[threadIdx.x] = 1;
  if (!is_leader) {
    return false;
  }

  Node &root = nodes_stack[node_index][0];
  root.bd = arys.bd_ary[arys.index];
  root.puttable = ~(root.bd.player | root.bd.opponent);
  root.alpha = -64;
  root.beta = 64;
  root.pass = true;
  root.passed = false;
  return false;
}
// Mask of the simdWidth lanes, within the current warp, that form the calling
// thread's SIMD group. Assumes a 1-D block (so threadIdx.x % 32 is the lane id)
// and that simdWidth is a power of two dividing 32.
static __device__ __forceinline__ unsigned simd_group_mask() {
  const unsigned lane = threadIdx.x % 32u;
  return ((1u << simdWidth) - 1u) << (lane & ~(simdWidth - 1u));
}

// Reports the final value of the top stack frame (from that frame's side to
// move). Below the root, lane 0 folds it into the parent and the group pops a
// frame. At the root, the value becomes the board's result via get_next_node.
// Returns true when the group has no boards left.
static __device__ bool finish_node(Arrays &arys, int node_index, int simd_index,
                                   int &stack_index, int value) {
  if (stack_index > 0) {
    if (simd_index == 0) {
      nodes_stack[node_index][stack_index - 1].update(value);
    }
    --stack_index;
    return false;
  }
  return get_next_node(arys, node_index, simd_index, value);
}

// Iterative negamax alpha-beta search (no move ordering) over the SIMD group's
// explicit stack nodes_stack[node_index][0..MAX_DEPTH]. Frame 0 must already
// hold the root. All simdWidth lanes of the group walk the same tree:
//  - lane k computes the flips along direction k, and the partial results are
//    ORed with lane shuffles;
//  - only lane 0 writes shared state.
// Boards are solved one after another until get_next_node reports none left.
// Uses __syncwarp and *_sync shuffles (CUDA 9+), which are required on Volta+
// where the lanes of a warp are not implicitly synchronous.
__device__ void alpha_beta(Arrays &arys) {
  const int node_index = threadIdx.x / simdWidth;
  const int simd_index = threadIdx.x % simdWidth;
  const unsigned group = simd_group_mask();
  int stack_index = 0;
  while (true) {
    // Make lane 0's writes from the previous step (or the caller's root setup)
    // visible to the whole group before anyone reads the frame.
    __syncwarp(group);
    Node &node = nodes_stack[node_index][stack_index];
    const Node cur = node;  // every lane decides from the same snapshot
    // No lane may still be reading `node` once lane 0 starts modifying it.
    __syncwarp(group);

    if (cur.puttable == 0) {
      if (!cur.pass) {
        // All moves searched. alpha is in the current bd.player's view; negate
        // it back if a pass swapped the board.
        if (finish_node(arys, node_index, simd_index, stack_index,
                        (cur.passed ? -1 : 1) * cur.alpha))
          return;
      } else if (cur.passed) {
        // Neither side can move: the game is over. cur.bd is the swapped view,
        // so negate to get the value for the frame's original side to move.
        if (finish_node(arys, node_index, simd_index, stack_index, -score(cur.bd)))
          return;
      } else if (simd_index == 0) {
        // No legal move: pass. Swap sides, negate and swap the window, and
        // enumerate the other side's moves in the same frame.
        node.bd = Board(cur.bd.opponent, cur.bd.player);
        node.puttable = ~(cur.bd.player | cur.bd.opponent);
        node.alpha = -cur.beta;
        node.beta = -cur.alpha;
        node.passed = true;
      }
    } else if (cur.alpha >= cur.beta) {
      // Beta cutoff.
      if (finish_node(arys, node_index, simd_index, stack_index,
                      (cur.passed ? -1 : 1) * cur.alpha))
        return;
    } else {
      // Try the lowest remaining empty square.
      const ull bit = cur.puttable & -cur.puttable;
      if (simd_index == 0) {
        node.puttable = cur.puttable ^ bit;
      }
      const int pos = __popcll(bit - 1);
      ull flipped = flip(cur.bd, pos, simd_index);
      flipped |= __shfl_xor_sync(group, flipped, 1, simdWidth);
      flipped |= __shfl_xor_sync(group, flipped, 2, simdWidth);
      if (flipped) {
        // Legal move: push the child with the negated, swapped window. The
        // stack has MAX_DEPTH + 1 frames, so pushing past MAX_DEPTH would
        // overrun shared memory.
        assert(stack_index < MAX_DEPTH);
        ++stack_index;
        if (simd_index == 0) {
          Node &next = nodes_stack[node_index][stack_index];
          node.pass = false;
          next.bd = Board(cur.bd.opponent ^ flipped, (cur.bd.player ^ flipped) | bit);
          next.puttable = ~(next.bd.player | next.bd.opponent);
          next.alpha = -cur.beta;
          next.beta = -cur.alpha;
          next.pass = true;
          next.passed = false;
          ++count[threadIdx.x];
        }
      }
    }
  }
}
// Kernel: solves the `size` boards in bd_ary by alpha-beta search without move
// ordering.
//  - Output: res_ary[i] receives the result for bd_ary[i] (from its side to
//    move) and nodes_count[i] the node counter.
//  - Layout: expects a 1-D block of at most threadsPerBlock threads (the
//    shared stacks and counters are sized for that). Each group of simdWidth
//    consecutive threads solves one board at a time and then strides through
//    the batch; any grid size works.
__global__ void search_noordering(const Board *bd_ary, int *res_ary, int *nodes_count, const size_t size) {
  assert(blockDim.x <= threadsPerBlock);
  // Global index of this thread's SIMD group = index of its first board.
  // Computed in size_t so the product cannot wrap in 32 bits.
  const size_t index =
      (static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x) / simdWidth;
  // The grid can contain more SIMD groups than boards. Such a group has nothing
  // to solve and must not touch bd_ary / res_ary / nodes_count. All lanes of a
  // group share `index`, so they leave together; no block barrier follows.
  if (index >= size) return;

  const int simd_index = threadIdx.x % simdWidth;
  const int node_index = threadIdx.x / simdWidth;
  count[threadIdx.x] = 1;
  if (simd_index == 0) {
    Node &root = nodes_stack[node_index][0];
    root.bd = bd_ary[index];
    root.puttable = ~(root.bd.player | root.bd.opponent);
    root.alpha = -64;
    root.beta = 64;
    root.pass = true;
    root.passed = false;
  }
  Arrays arys = {
    bd_ary,
    res_ary,
    nodes_count,
    index,
    size
  };
  alpha_beta(arys);
}
// Decoder for the Base81 board format. Each input line holds 16 characters, and
// character i describes squares 4i .. 4i+3 (one base-3 digit per square):
//   0 = empty
//   1 = a disc of the side to move (Board::player)
//   2 = a disc of the other side (Board::opponent)
struct Base81 {
  ull table_p[256];  // table_p[c]: 4-bit player mask encoded by character c
  ull table_o[256];  // table_o[c]: 4-bit opponent mask encoded by character c
  bool valid[256];   // valid[c]: c is one of the 81 characters of the encoding

  // Builds the lookup tables. Digits (i, j, k, l) for squares 4n..4n+3 map to
  // character code i + 3j + 9k + 32l + 33. That gives three runs of 27 codes,
  // starting at '!' (33), 'A' (65) and 'a' (97). All other entries stay zero
  // and are marked invalid (they used to be left uninitialised).
  Base81() : table_p(), table_o(), valid() {
    for (int i = 0; i < 3; ++i) {
      for (int j = 0; j < 3; ++j) {
        for (int k = 0; k < 3; ++k) {
          for (int l = 0; l < 3; ++l) {
            const int index = i + 3*j + 9*k + 32*l + 33;
            table_p[index] = (i==1) + 2*(j==1) + 4*(k==1) + 8*(l==1);
            table_o[index] = (i/2) + 2*(j/2) + 4*(k/2) + 8*(l/2);
            valid[index] = true;
          }
        }
      }
    }
  }

  // Debug helper: prints bd to stderr as an 8x8 grid, bit 0 first
  // ('x' player, 'o' opponent, '.' empty).
  __host__ void output(const Board &bd) {
    for (int i = 0; i < 64; ++i) {
      if ((bd.player >> i) & 1) {
        fputc('x', stderr);
      } else if ((bd.opponent >> i) & 1) {
        fputc('o', stderr);
      } else {
        fputc('.', stderr);
      }
      if ((i % 8) == 7) fputc('\n', stderr);
    }
    fputc('\n', stderr);
  }

  // Reads and decodes one board line from fp. `l` is the 0-based board index,
  // used only in error messages (reported 1-based). Exits the process on EOF,
  // on a line shorter than 16 characters, or on a character outside the
  // encoding.
  __host__ Board input_bd(FILE *fp, int l) {
    char buf[18];  // 16 characters, the newline and the terminating NUL
    if (fgets(buf, sizeof buf, fp) == NULL) {
      fprintf(stderr, "unexpected EOF\n");
      exit(-1);
    }
    if (strlen(buf) < 16) {
      fprintf(stderr, "too short input: line %d\n", l+1);
      exit(-1);
    }
    ull p = 0, o = 0;
    for (int i = 0; i < 16; ++i) {
      // Index through unsigned char: plain char may be signed, and a negative
      // index would read outside the tables.
      const unsigned char c = static_cast<unsigned char>(buf[i]);
      if (!valid[c]) {
        fprintf(stderr, "invalid character 0x%02x: line %d, column %d\n", c, l+1, i+1);
        exit(-1);
      }
      p |= table_p[c] << (i*4);
      o |= table_o[c] << (i*4);
    }
    assert(!(p & o));  // a digit is never both 1 and 2
    return Board(p, o);
  }
};
// Aborts with a message when a CUDA runtime call fails.
#define CUDA_CHECK(call)                                                  \
  do {                                                                    \
    cudaError_t err_ = (call);                                            \
    if (err_ != cudaSuccess) {                                            \
      fprintf(stderr, "CUDA error %s:%d: %s\n", __FILE__, __LINE__,       \
              cudaGetErrorString(err_));                                  \
      exit(EXIT_FAILURE);                                                 \
    }                                                                     \
  } while (0)

// Number of squares on bd held by neither side.
static int count_empties(const Board &bd) {
  int empties = 0;
  for (ull rest = ~(bd.player | bd.opponent); rest != 0; rest &= rest - 1) {
    ++empties;
  }
  return empties;
}

// Usage: prog INPUT OUTPUT
// INPUT holds a board count followed by one Base81 line per board. OUTPUT
// receives one result per line, in input order. Progress messages and the
// total node count go to stderr.
int main(int argc, char **argv) {
  if (argc < 3) {
    fprintf(stderr, "%s [INPUT] [OUTPUT]\n", argv[0]);
    return EXIT_FAILURE;
  }
  FILE *fp = fopen(argv[1], "r");
  if (fp == nullptr) {
    fprintf(stderr, "no such file: %s\n", argv[1]);
    return EXIT_FAILURE;
  }
  FILE *ofp = fopen(argv[2], "w");
  if (ofp == nullptr) {  // previously re-tested fp, so this failure went unnoticed
    fprintf(stderr, "cannot open file: %s\n", argv[2]);
    fclose(fp);
    return EXIT_FAILURE;
  }
  int n;
  if (fscanf(fp, "%d ", &n) != 1 || n < 0) {
    fprintf(stderr, "invalid data size\n");
    fclose(fp);
    fclose(ofp);
    return EXIT_FAILURE;
  }
  fprintf(stderr, "start read, data size is %d\n", n);
  if (n == 0) {  // nothing to solve; a launch with zero blocks would be an error
    fclose(fp);
    fclose(ofp);
    fprintf(stderr, "%llu\n", 0ULL);
    return EXIT_SUCCESS;
  }

  Base81 b81;
  Board *bd_ary_host = static_cast<Board *>(malloc(sizeof(Board) * n));
  if (bd_ary_host == nullptr) {
    fprintf(stderr, "out of memory\n");
    return EXIT_FAILURE;
  }
  for (int i = 0; i < n; ++i) {
    bd_ary_host[i] = b81.input_bd(fp, i);
    // The kernel's per-board stack has MAX_DEPTH + 1 frames and every move
    // pushes one, so deeper boards would overrun shared memory.
    if (count_empties(bd_ary_host[i]) > MAX_DEPTH) {
      fprintf(stderr, "too many empty squares (max %d): line %d\n", MAX_DEPTH, i + 1);
      return EXIT_FAILURE;
    }
  }
  fclose(fp);

  Board *bd_ary;
  int *res_ary;
  int *nodes_count;
  CUDA_CHECK(cudaMalloc(&bd_ary, sizeof(Board) * n));
  CUDA_CHECK(cudaMallocManaged(&res_ary, sizeof(int) * n));
  CUDA_CHECK(cudaMallocManaged(&nodes_count, sizeof(int) * n));
  CUDA_CHECK(cudaMemcpy(bd_ary, bd_ary_host, sizeof(Board) * n, cudaMemcpyHostToDevice));
  CUDA_CHECK(cudaMemset(res_ary, 0, sizeof(int) * n));
  CUDA_CHECK(cudaMemset(nodes_count, 0, sizeof(int) * n));

  fputs("start solve\n", stderr);
  // One SIMD group per board, capped at 256 blocks; the kernel strides over
  // any boards beyond that.
  const int needed_blocks = n / nodesPerBlock + (n % nodesPerBlock != 0);
  const int blocks = needed_blocks < 256 ? needed_blocks : 256;
  search_noordering<<<blocks, threadsPerBlock>>>(bd_ary, res_ary, nodes_count, n);
  CUDA_CHECK(cudaGetLastError());       // launch-configuration errors
  CUDA_CHECK(cudaDeviceSynchronize());  // errors raised while the kernel ran
  fputs("end solve\n", stderr);

  ull nodes_total = 0;
  for (int i = 0; i < n; ++i) {
    fprintf(ofp, "%d\n", res_ary[i]);
    nodes_total += nodes_count[i];
  }
  fprintf(stderr, "%llu\n", nodes_total);
  if (fclose(ofp) != 0) {  // buffered write errors surface here
    perror(argv[2]);
  }

  free(bd_ary_host);
  CUDA_CHECK(cudaFree(bd_ary));
  CUDA_CHECK(cudaFree(res_ary));
  CUDA_CHECK(cudaFree(nodes_count));
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment