Last active
March 27, 2021 05:49
-
-
Save TadaoYamaoka/46a0c0baed7bcafbb4963a6056df3426 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <immintrin.h>

#include <algorithm>
#include <cfloat>
#include <chrono>
#include <cstddef>
#include <iostream>
// Count-trailing-zeros shim for MSVC, which has no __builtin_ctz.
// GCC and Clang (including clang-cl) already provide __builtin_ctz as a
// compiler builtin and reject a user definition of it, so the shim is only
// compiled for the genuine MSVC front end.
// As with the builtin, the result must not be relied upon when x == 0.
#if defined(_MSC_VER) && !defined(__clang__)
inline unsigned int __builtin_ctz(unsigned int x)
{
    unsigned long index = 0;
    _BitScanForward(&index, x);
    return static_cast<unsigned int>(index);
}
#endif
// Broadcast constants shared by the AVX2 kernels below.
// Each one holds the same value in all eight 32-bit lanes.
const __m256i m256i_zero = _mm256_set1_epi32(0);   // integer 0
const __m256i m256i_one = _mm256_set1_epi32(1);    // integer 1
const __m256 m256_one = _mm256_set1_ps(1.0f);      // float 1.0
const __m256i m256i_eight = _mm256_set1_epi32(8);  // integer 8 (children handled per SIMD step)
inline unsigned int SelectMaxUcbChild(const int child_num, float* win, int* move_count, float* nnrate) | |
{ | |
float c = 1.2f; | |
float parent_q = 0.55f; | |
float init_u = 0.66f; | |
float sqrt_sum = 1.5f; | |
__m256 m256_c = _mm256_broadcast_ss(&c); | |
__m256 m256_parent_q = _mm256_broadcast_ss(&parent_q); | |
__m256 m256_init_u = _mm256_broadcast_ss(&init_u); | |
__m256 m256_sqrt_sum = _mm256_broadcast_ss(&sqrt_sum); | |
// UCB値最大の手を求める | |
unsigned int max_child = 0; | |
//float max_value = -FLT_MAX; | |
__m256 vmaxvalue = _mm256_set1_ps(-FLT_MAX); | |
__m256i vnowposition = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); | |
__m256i vmaxposition = vnowposition; | |
for (size_t i = 0; i < child_num; i += 8) { | |
if (i + 8 > child_num) { | |
// 残り8未満 | |
__m256i mask_rest; | |
switch (child_num - i) { | |
case 1: | |
mask_rest = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1); | |
break; | |
case 2: | |
mask_rest = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1); | |
break; | |
case 3: | |
mask_rest = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1); | |
break; | |
case 4: | |
mask_rest = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1); | |
break; | |
case 5: | |
mask_rest = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1); | |
break; | |
case 6: | |
mask_rest = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1); | |
break; | |
case 7: | |
mask_rest = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1); | |
break; | |
default: | |
// unreachable | |
mask_rest = _mm256_set1_epi32(0); | |
break; | |
} | |
__m256i m256i_move_count = _mm256_maskload_epi32(move_count + i, mask_rest); | |
__m256i mask = _mm256_cmpgt_epi32(m256i_move_count, m256i_zero); | |
// q = (float)(win / move_count); | |
__m256 m256_win = _mm256_maskload_ps(win + i, mask_rest); | |
__m256 m256_move_count = _mm256_cvtepi32_ps(m256i_move_count); | |
__m256 m256_q_tmp = _mm256_div_ps(m256_win, m256_move_count); | |
__m256 m256_q = _mm256_blendv_ps(m256_parent_q, m256_q_tmp, _mm256_castsi256_ps(mask)); | |
// u = sqrt_sum / (1 + move_count); | |
__m256 m256_move_count_plus1 = _mm256_add_ps(m256_move_count, m256_one); | |
__m256 m256_u_tmp = _mm256_div_ps(m256_sqrt_sum, m256_move_count_plus1); | |
__m256 m256_u = _mm256_blendv_ps(m256_init_u, m256_u_tmp, _mm256_castsi256_ps(mask)); | |
__m256 m256_rate = _mm256_maskload_ps(nnrate + i, mask_rest); | |
//const float ucb_value = q + c * u * rate; | |
__m256 m256_ucb_value = _mm256_mul_ps(m256_c, m256_u); | |
m256_ucb_value = _mm256_mul_ps(m256_ucb_value, m256_rate); | |
m256_ucb_value = _mm256_add_ps(m256_q, m256_ucb_value); | |
// mask | |
m256_ucb_value = _mm256_and_ps(m256_ucb_value, _mm256_castsi256_ps(mask_rest)); | |
// find max | |
__m256 vcmp = _mm256_cmp_ps(m256_ucb_value, vmaxvalue, _CMP_GT_OS); | |
vmaxvalue = _mm256_max_ps(m256_ucb_value, vmaxvalue); | |
vmaxposition = _mm256_blendv_epi8(vmaxposition, vnowposition, _mm256_castps_si256(vcmp)); | |
vnowposition = _mm256_add_epi32(vnowposition, m256i_eight); | |
break; | |
} | |
// 未実装 | |
//if (uct_child[i].IsWin()) { | |
// child_win_count++; | |
// // 負けが確定しているノードは選択しない | |
// continue; | |
//} | |
//else if (uct_child[i].IsLose()) { | |
// // 子ノードに一つでも負けがあれば、自ノードを勝ちにできる | |
// if (parent != nullptr) | |
// parent->SetWin(); | |
// // 勝ちが確定しているため、選択する | |
// return i; | |
//} | |
//const WinType win = uct_child[i].win; | |
//const int move_count = uct_child[i].move_count; | |
// | |
//if (move_count == 0) { | |
__m256i m256i_move_count = _mm256_load_si256((__m256i*)(move_count + i)); | |
__m256i mask = _mm256_cmpgt_epi32(m256i_move_count, m256i_zero); | |
// // 未探索のノードの価値に、親ノードの価値を使用する | |
// q = parent_q; | |
// u = init_u; | |
// --> 下記のelseの計算結果と合わせて_mm256_blendv_psで設定する | |
//} | |
//else { | |
// q = (float)(win / move_count); | |
__m256 m256_win = _mm256_load_ps(win + i); | |
__m256 m256_move_count = _mm256_cvtepi32_ps(m256i_move_count); | |
__m256 m256_q_tmp = _mm256_div_ps(m256_win, m256_move_count); | |
// ニュートン-ラフソン法で割り算を高速化 | |
//__m256 rc = _mm256_rcp_ps(m256_move_count); | |
//__m256 rcip = _mm256_sub_ps(_mm256_add_ps(rc, rc), _mm256_mul_ps(rc, _mm256_mul_ps(rc, m256_move_count))); | |
//__m256 m256_q_tmp = _mm256_mul_ps(m256_win, rcip); | |
__m256 m256_q = _mm256_blendv_ps(m256_parent_q, m256_q_tmp, _mm256_castsi256_ps(mask)); | |
// u = sqrt_sum / (1 + move_count); | |
__m256 m256_move_count_plus1 = _mm256_add_ps(m256_move_count, m256_one); | |
__m256 m256_u_tmp = _mm256_div_ps(m256_sqrt_sum, m256_move_count_plus1); | |
// ニュートン-ラフソン法で割り算を高速化 | |
//rc = _mm256_rcp_ps(m256_move_count_plus1); | |
//rcip = _mm256_sub_ps(_mm256_add_ps(rc, rc), _mm256_mul_ps(rc, _mm256_mul_ps(rc, m256_move_count_plus1))); | |
//__m256 m256_u_tmp = _mm256_mul_ps(m256_sqrt_sum, rcip); | |
__m256 m256_u = _mm256_blendv_ps(m256_init_u, m256_u_tmp, _mm256_castsi256_ps(mask)); | |
//} | |
//const float rate = uct_child[i].nnrate; | |
__m256 m256_rate = _mm256_load_ps(nnrate + i); | |
//const float ucb_value = q + c * u * rate; | |
__m256 m256_ucb_value = _mm256_mul_ps(m256_c, m256_u); | |
m256_ucb_value = _mm256_mul_ps(m256_ucb_value, m256_rate); | |
m256_ucb_value = _mm256_add_ps(m256_q, m256_ucb_value); | |
/*const float* ucb_values = (float*)&m256_ucb_value; | |
if (ucb_values[0] > max_value) { | |
max_value = ucb_values[0]; | |
max_child = 0 + i; | |
} | |
if (ucb_values[1] > max_value) { | |
max_value = ucb_values[1]; | |
max_child = 1 + i; | |
} | |
if (ucb_values[2] > max_value) { | |
max_value = ucb_values[2]; | |
max_child = 2 + i; | |
} | |
if (ucb_values[3] > max_value) { | |
max_value = ucb_values[3]; | |
max_child = 3 + i; | |
} | |
if (ucb_values[4] > max_value) { | |
max_value = ucb_values[4]; | |
max_child = 4 + i; | |
} | |
if (ucb_values[5] > max_value) { | |
max_value = ucb_values[5]; | |
max_child = 5 + i; | |
} | |
if (ucb_values[6] > max_value) { | |
max_value = ucb_values[6]; | |
max_child = 6 + i; | |
} | |
if (ucb_values[7] > max_value) { | |
max_value = ucb_values[7]; | |
max_child = 7 + i; | |
}*/ | |
/*__m256 vmax = m256_ucb_value; | |
vmax = _mm256_max_ps(vmax, _mm256_castsi256_ps(_mm256_alignr_epi8(_mm256_castps_si256(vmax), _mm256_castps_si256(vmax), 4))); | |
vmax = _mm256_max_ps(vmax, _mm256_castsi256_ps(_mm256_alignr_epi8(_mm256_castps_si256(vmax), _mm256_castps_si256(vmax), 8))); | |
vmax = _mm256_max_ps(vmax, _mm256_castsi256_ps(_mm256_permute2x128_si256(_mm256_castps_si256(vmax), _mm256_castps_si256(vmax), 0x01))); | |
const float value = ((float*)&vmax)[0]; | |
if (value > max_value) { | |
max_value = value; | |
__m256 vcmp = _mm256_cmp_ps(m256_ucb_value, vmax, _CMP_EQ_US); | |
int mask = _mm256_movemask_ps(vcmp); | |
max_child = i + __builtin_ctz(mask); | |
}*/ | |
__m256 vcmp = _mm256_cmp_ps(m256_ucb_value, vmaxvalue, _CMP_GT_OS); | |
vmaxvalue = _mm256_max_ps(m256_ucb_value, vmaxvalue); | |
vmaxposition = _mm256_blendv_epi8(vmaxposition, vnowposition, _mm256_castps_si256(vcmp)); | |
vnowposition = _mm256_add_epi32(vnowposition, m256i_eight); | |
} | |
const int* maxposition = (int*)&vmaxposition; | |
__m256 vallmax = _mm256_max_ps(vmaxvalue, _mm256_shuffle_ps(vmaxvalue, vmaxvalue, 0xb1)); | |
vallmax = _mm256_max_ps(vallmax, _mm256_shuffle_ps(vallmax, vallmax, 0x4e)); | |
vallmax = _mm256_max_ps(vallmax, _mm256_permute2f128_ps(vallmax, vallmax, 0x01)); | |
__m256 vcmp = _mm256_cmp_ps(vallmax, vmaxvalue, _CMP_EQ_US); | |
int mask = _mm256_movemask_ps(vcmp); | |
max_child = maxposition[__builtin_ctz(mask)]; | |
return max_child; | |
} | |
/// Scalar reference implementation of UCB child selection.
///
/// For every child i the score is
///   q   = move_count[i] != 0 ? win[i] / move_count[i]         : parent_q
///   u   = move_count[i] != 0 ? sqrt_sum / (1 + move_count[i]) : init_u
///   ucb = q + c * u * nnrate[i]
/// where c, parent_q, init_u and sqrt_sum are the fixed benchmark constants
/// set at the top of the body.
///
/// The handling of proven-win / proven-loss children that the full search
/// performs (skip children that are proven wins, immediately pick a child that
/// is a proven loss and mark the parent as won) is not implemented here.
///
/// @param child_num  Number of children. Values <= 0 yield 0.
/// @param win        Array of child_num accumulated win values.
/// @param move_count Array of child_num visit counts.
/// @param nnrate     Array of child_num prior rates.
/// @return Index of the first child with the largest score, or 0 when no
///         child scores above -FLT_MAX.
inline unsigned int SelectMaxUcbChildOrg(const int child_num, float* win, int* move_count, float* nnrate)
{
    // Fixed search parameters used by this benchmark.
    const float c = 1.2f;
    const float parent_q = 0.55f;
    const float init_u = 0.66f;
    const float sqrt_sum = 1.5f;

    // Find the move with the largest UCB value.
    unsigned int max_child = 0;
    float max_value = -FLT_MAX;

    // The index uses the parameter's own signed type: comparing an unsigned
    // index against a negative child_num would turn the bound into a huge
    // value and read far past the arrays.
    for (int i = 0; i < child_num; ++i) {
        const int visits = move_count[i];

        float q;
        float u;
        if (visits == 0) {
            // An unexplored child borrows the parent's value.
            q = parent_q;
            u = init_u;
        }
        else {
            q = win[i] / visits;
            u = sqrt_sum / (1 + visits);
        }

        const float ucb_value = q + c * u * nnrate[i];

        // Strict '>' keeps the first of several equal maxima.
        if (ucb_value > max_value) {
            max_value = ucb_value;
            max_child = static_cast<unsigned int>(i);
        }
    }
    return max_child;
}
int main() | |
{ | |
/*float v[8] = { 5, 2, 6, 8, 0, 1, 4, 7 }; | |
unsigned int index = _mm256_hmax_index(*(__m256*)v); | |
std::cout << index << std::endl;*/ | |
constexpr int child_num = 16; | |
// 32bitでアライメントされたメモリを初期化 | |
float* win = (float*)_mm_malloc(sizeof(float) * child_num, 32); | |
constexpr float win_init[child_num]{ 0, 1, 1.5, 0, 3, 2.5, 3.5, 4.5, 0, 1, 1.5, 0, 3, 2.5, 3.5, 0 }; | |
std::copy(win_init, win_init + child_num, win); | |
int* move_count = (int*)_mm_malloc(sizeof(int) * child_num, 32); | |
constexpr int move_count_init[child_num]{ 1, 1, 2, 0, 4, 5, 6, 7, 1, 1, 2, 0, 4, 5, 6, 0 }; | |
std::copy(move_count_init, move_count_init + child_num, move_count); | |
float* nnrate = (float*)_mm_malloc(sizeof(float) * child_num, 32); | |
constexpr float nnrate_init[child_num]{ 0.13, 0.18, 0.09, 0.16, 0.07, 0.16, 0.20, 0.00, 0.13, 0.18, 0.09, 0.16, 0.07, 0.16, 0.20, 0 }; | |
std::copy(nnrate_init, nnrate_init + child_num, nnrate); | |
// 測定 | |
auto time_point = std::chrono::high_resolution_clock::now(); | |
unsigned int dummy = 0; | |
for (size_t i = 0; i < 10000000; ++i) | |
dummy += SelectMaxUcbChild(child_num, win, move_count, nnrate); | |
auto duration = std::chrono::high_resolution_clock::now() - time_point; | |
std::cout << dummy << std::endl; | |
std::cout << std::chrono::duration_cast<std::chrono::microseconds>(duration).count() << std::endl; | |
_mm_free(win); | |
_mm_free(move_count); | |
_mm_free(nnrate); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment