Skip to content

Instantly share code, notes, and snippets.

@TadaoYamaoka
Last active March 27, 2021 05:49
Show Gist options
  • Save TadaoYamaoka/46a0c0baed7bcafbb4963a6056df3426 to your computer and use it in GitHub Desktop.
Save TadaoYamaoka/46a0c0baed7bcafbb4963a6056df3426 to your computer and use it in GitHub Desktop.
#include <iostream>
#include <cfloat>
#include <chrono>
#include <immintrin.h>
inline unsigned int __builtin_ctz(unsigned int x) { unsigned long r; _BitScanForward(&r, x); return r; }
const __m256i m256i_zero{};
const __m256i m256i_one = _mm256_set1_epi32(1);
const __m256 m256_one = _mm256_set1_ps(1);
const __m256i m256i_eight = _mm256_set1_epi32(8);
inline unsigned int SelectMaxUcbChild(const int child_num, float* win, int* move_count, float* nnrate)
{
float c = 1.2f;
float parent_q = 0.55f;
float init_u = 0.66f;
float sqrt_sum = 1.5f;
__m256 m256_c = _mm256_broadcast_ss(&c);
__m256 m256_parent_q = _mm256_broadcast_ss(&parent_q);
__m256 m256_init_u = _mm256_broadcast_ss(&init_u);
__m256 m256_sqrt_sum = _mm256_broadcast_ss(&sqrt_sum);
// UCB値最大の手を求める
unsigned int max_child = 0;
//float max_value = -FLT_MAX;
__m256 vmaxvalue = _mm256_set1_ps(-FLT_MAX);
__m256i vnowposition = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
__m256i vmaxposition = vnowposition;
for (size_t i = 0; i < child_num; i += 8) {
if (i + 8 > child_num) {
// 残り8未満
__m256i mask_rest;
switch (child_num - i) {
case 1:
mask_rest = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1);
break;
case 2:
mask_rest = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1);
break;
case 3:
mask_rest = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1);
break;
case 4:
mask_rest = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1);
break;
case 5:
mask_rest = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1);
break;
case 6:
mask_rest = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1);
break;
case 7:
mask_rest = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1);
break;
default:
// unreachable
mask_rest = _mm256_set1_epi32(0);
break;
}
__m256i m256i_move_count = _mm256_maskload_epi32(move_count + i, mask_rest);
__m256i mask = _mm256_cmpgt_epi32(m256i_move_count, m256i_zero);
// q = (float)(win / move_count);
__m256 m256_win = _mm256_maskload_ps(win + i, mask_rest);
__m256 m256_move_count = _mm256_cvtepi32_ps(m256i_move_count);
__m256 m256_q_tmp = _mm256_div_ps(m256_win, m256_move_count);
__m256 m256_q = _mm256_blendv_ps(m256_parent_q, m256_q_tmp, _mm256_castsi256_ps(mask));
// u = sqrt_sum / (1 + move_count);
__m256 m256_move_count_plus1 = _mm256_add_ps(m256_move_count, m256_one);
__m256 m256_u_tmp = _mm256_div_ps(m256_sqrt_sum, m256_move_count_plus1);
__m256 m256_u = _mm256_blendv_ps(m256_init_u, m256_u_tmp, _mm256_castsi256_ps(mask));
__m256 m256_rate = _mm256_maskload_ps(nnrate + i, mask_rest);
//const float ucb_value = q + c * u * rate;
__m256 m256_ucb_value = _mm256_mul_ps(m256_c, m256_u);
m256_ucb_value = _mm256_mul_ps(m256_ucb_value, m256_rate);
m256_ucb_value = _mm256_add_ps(m256_q, m256_ucb_value);
// mask
m256_ucb_value = _mm256_and_ps(m256_ucb_value, _mm256_castsi256_ps(mask_rest));
// find max
__m256 vcmp = _mm256_cmp_ps(m256_ucb_value, vmaxvalue, _CMP_GT_OS);
vmaxvalue = _mm256_max_ps(m256_ucb_value, vmaxvalue);
vmaxposition = _mm256_blendv_epi8(vmaxposition, vnowposition, _mm256_castps_si256(vcmp));
vnowposition = _mm256_add_epi32(vnowposition, m256i_eight);
break;
}
// 未実装
//if (uct_child[i].IsWin()) {
// child_win_count++;
// // 負けが確定しているノードは選択しない
// continue;
//}
//else if (uct_child[i].IsLose()) {
// // 子ノードに一つでも負けがあれば、自ノードを勝ちにできる
// if (parent != nullptr)
// parent->SetWin();
// // 勝ちが確定しているため、選択する
// return i;
//}
//const WinType win = uct_child[i].win;
//const int move_count = uct_child[i].move_count;
//
//if (move_count == 0) {
__m256i m256i_move_count = _mm256_load_si256((__m256i*)(move_count + i));
__m256i mask = _mm256_cmpgt_epi32(m256i_move_count, m256i_zero);
// // 未探索のノードの価値に、親ノードの価値を使用する
// q = parent_q;
// u = init_u;
// --> 下記のelseの計算結果と合わせて_mm256_blendv_psで設定する
//}
//else {
// q = (float)(win / move_count);
__m256 m256_win = _mm256_load_ps(win + i);
__m256 m256_move_count = _mm256_cvtepi32_ps(m256i_move_count);
__m256 m256_q_tmp = _mm256_div_ps(m256_win, m256_move_count);
// ニュートン-ラフソン法で割り算を高速化
//__m256 rc = _mm256_rcp_ps(m256_move_count);
//__m256 rcip = _mm256_sub_ps(_mm256_add_ps(rc, rc), _mm256_mul_ps(rc, _mm256_mul_ps(rc, m256_move_count)));
//__m256 m256_q_tmp = _mm256_mul_ps(m256_win, rcip);
__m256 m256_q = _mm256_blendv_ps(m256_parent_q, m256_q_tmp, _mm256_castsi256_ps(mask));
// u = sqrt_sum / (1 + move_count);
__m256 m256_move_count_plus1 = _mm256_add_ps(m256_move_count, m256_one);
__m256 m256_u_tmp = _mm256_div_ps(m256_sqrt_sum, m256_move_count_plus1);
// ニュートン-ラフソン法で割り算を高速化
//rc = _mm256_rcp_ps(m256_move_count_plus1);
//rcip = _mm256_sub_ps(_mm256_add_ps(rc, rc), _mm256_mul_ps(rc, _mm256_mul_ps(rc, m256_move_count_plus1)));
//__m256 m256_u_tmp = _mm256_mul_ps(m256_sqrt_sum, rcip);
__m256 m256_u = _mm256_blendv_ps(m256_init_u, m256_u_tmp, _mm256_castsi256_ps(mask));
//}
//const float rate = uct_child[i].nnrate;
__m256 m256_rate = _mm256_load_ps(nnrate + i);
//const float ucb_value = q + c * u * rate;
__m256 m256_ucb_value = _mm256_mul_ps(m256_c, m256_u);
m256_ucb_value = _mm256_mul_ps(m256_ucb_value, m256_rate);
m256_ucb_value = _mm256_add_ps(m256_q, m256_ucb_value);
/*const float* ucb_values = (float*)&m256_ucb_value;
if (ucb_values[0] > max_value) {
max_value = ucb_values[0];
max_child = 0 + i;
}
if (ucb_values[1] > max_value) {
max_value = ucb_values[1];
max_child = 1 + i;
}
if (ucb_values[2] > max_value) {
max_value = ucb_values[2];
max_child = 2 + i;
}
if (ucb_values[3] > max_value) {
max_value = ucb_values[3];
max_child = 3 + i;
}
if (ucb_values[4] > max_value) {
max_value = ucb_values[4];
max_child = 4 + i;
}
if (ucb_values[5] > max_value) {
max_value = ucb_values[5];
max_child = 5 + i;
}
if (ucb_values[6] > max_value) {
max_value = ucb_values[6];
max_child = 6 + i;
}
if (ucb_values[7] > max_value) {
max_value = ucb_values[7];
max_child = 7 + i;
}*/
/*__m256 vmax = m256_ucb_value;
vmax = _mm256_max_ps(vmax, _mm256_castsi256_ps(_mm256_alignr_epi8(_mm256_castps_si256(vmax), _mm256_castps_si256(vmax), 4)));
vmax = _mm256_max_ps(vmax, _mm256_castsi256_ps(_mm256_alignr_epi8(_mm256_castps_si256(vmax), _mm256_castps_si256(vmax), 8)));
vmax = _mm256_max_ps(vmax, _mm256_castsi256_ps(_mm256_permute2x128_si256(_mm256_castps_si256(vmax), _mm256_castps_si256(vmax), 0x01)));
const float value = ((float*)&vmax)[0];
if (value > max_value) {
max_value = value;
__m256 vcmp = _mm256_cmp_ps(m256_ucb_value, vmax, _CMP_EQ_US);
int mask = _mm256_movemask_ps(vcmp);
max_child = i + __builtin_ctz(mask);
}*/
__m256 vcmp = _mm256_cmp_ps(m256_ucb_value, vmaxvalue, _CMP_GT_OS);
vmaxvalue = _mm256_max_ps(m256_ucb_value, vmaxvalue);
vmaxposition = _mm256_blendv_epi8(vmaxposition, vnowposition, _mm256_castps_si256(vcmp));
vnowposition = _mm256_add_epi32(vnowposition, m256i_eight);
}
const int* maxposition = (int*)&vmaxposition;
__m256 vallmax = _mm256_max_ps(vmaxvalue, _mm256_shuffle_ps(vmaxvalue, vmaxvalue, 0xb1));
vallmax = _mm256_max_ps(vallmax, _mm256_shuffle_ps(vallmax, vallmax, 0x4e));
vallmax = _mm256_max_ps(vallmax, _mm256_permute2f128_ps(vallmax, vallmax, 0x01));
__m256 vcmp = _mm256_cmp_ps(vallmax, vmaxvalue, _CMP_EQ_US);
int mask = _mm256_movemask_ps(vcmp);
max_child = maxposition[__builtin_ctz(mask)];
return max_child;
}
inline unsigned int SelectMaxUcbChildOrg(const int child_num, float* win, int* move_count, float* nnrate)
{
float c = 1.2f;
float parent_q = 0.55f;
float init_u = 0.66f;
float sqrt_sum = 1.5f;
float q, u;
// UCB値最大の手を求める
unsigned int max_child = 0;
float max_value = -FLT_MAX;
for (size_t i = 0; i < child_num; ++i) {
// 未実装
//if (uct_child[i].IsWin()) {
// child_win_count++;
// // 負けが確定しているノードは選択しない
// continue;
//}
//else if (uct_child[i].IsLose()) {
// // 子ノードに一つでも負けがあれば、自ノードを勝ちにできる
// if (parent != nullptr)
// parent->SetWin();
// // 勝ちが確定しているため、選択する
// return i;
//}
const float win_ = win[i];
const int move_count_ = move_count[i];
if (move_count_ == 0) {
// 未探索のノードの価値に、親ノードの価値を使用する
q = parent_q;
u = init_u;
}
else {
q = (float)(win_ / move_count_);
u = sqrt_sum / (1 + move_count_);
}
const float rate = nnrate[i];
const float ucb_value = q + c * u * rate;
if (ucb_value > max_value) {
max_value = ucb_value;
max_child = i;
}
}
return max_child;
}
int main()
{
/*float v[8] = { 5, 2, 6, 8, 0, 1, 4, 7 };
unsigned int index = _mm256_hmax_index(*(__m256*)v);
std::cout << index << std::endl;*/
constexpr int child_num = 16;
// 32bitでアライメントされたメモリを初期化
float* win = (float*)_mm_malloc(sizeof(float) * child_num, 32);
constexpr float win_init[child_num]{ 0, 1, 1.5, 0, 3, 2.5, 3.5, 4.5, 0, 1, 1.5, 0, 3, 2.5, 3.5, 0 };
std::copy(win_init, win_init + child_num, win);
int* move_count = (int*)_mm_malloc(sizeof(int) * child_num, 32);
constexpr int move_count_init[child_num]{ 1, 1, 2, 0, 4, 5, 6, 7, 1, 1, 2, 0, 4, 5, 6, 0 };
std::copy(move_count_init, move_count_init + child_num, move_count);
float* nnrate = (float*)_mm_malloc(sizeof(float) * child_num, 32);
constexpr float nnrate_init[child_num]{ 0.13, 0.18, 0.09, 0.16, 0.07, 0.16, 0.20, 0.00, 0.13, 0.18, 0.09, 0.16, 0.07, 0.16, 0.20, 0 };
std::copy(nnrate_init, nnrate_init + child_num, nnrate);
// 測定
auto time_point = std::chrono::high_resolution_clock::now();
unsigned int dummy = 0;
for (size_t i = 0; i < 10000000; ++i)
dummy += SelectMaxUcbChild(child_num, win, move_count, nnrate);
auto duration = std::chrono::high_resolution_clock::now() - time_point;
std::cout << dummy << std::endl;
std::cout << std::chrono::duration_cast<std::chrono::microseconds>(duration).count() << std::endl;
_mm_free(win);
_mm_free(move_count);
_mm_free(nnrate);
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment