Created
March 25, 2020 19:29
-
-
Save sslotin/39a9d0dd2ddf2ebf3d7cf3da113addb4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma GCC optimize("O3") | |
#pragma GCC target("avx2") | |
#include <x86intrin.h> | |
#include <bits/stdc++.h> | |
using namespace std; | |
// Shorthand for a 256-bit AVX2 integer vector (eight 32-bit lanes).
using reg = __m256i;

// Number of elements in the searched array.
constexpr int n = 1 << 22;
// Number of benchmark queries.
constexpr int m = 1 << 20;

int a[n];        // haystack; filled with rand() and sorted in main()
int q[m];        // query keys passed to each search function
int results[m];  // per-query answers written by timeit()
int binsearch_stl(int x) { | |
return *lower_bound(a, a+n, x); | |
} | |
int binsearch_manual(int x) { | |
int l = 0, r = n; | |
while (l < r) { | |
int t = (l + r) / 2; | |
if (a[t] >= x) | |
r = t; | |
else | |
l = t + 1; | |
} | |
return a[l]; | |
} | |
int binsearch_prefetch(int x) { | |
int l = 0, r = n; | |
while (l < r) { | |
int t = (l + r) / 2; | |
__builtin_prefetch(a + (l + t) / 2); | |
__builtin_prefetch(a + (t + r + 1) / 2); | |
if (a[t] >= x) | |
r = t; | |
else | |
l = t + 1; | |
} | |
return a[l]; | |
} | |
int binsearch_hybrid(int x) { | |
int l = 0, r = n; | |
while (r - l > 12) { | |
int t = (l + r) / 2; | |
if (a[t] >= x) | |
r = t; | |
else | |
l = t + 1; | |
} | |
while (a[l] < x) | |
l++; | |
return a[l]; | |
} | |
// Debug helper: writes the eight 32-bit lanes of t to stderr, lane 0 first,
// each followed by a space, then ends the line.
void print(reg t) {
    int lanes[8];
    _mm256_storeu_si256(reinterpret_cast<reg *>(lanes), t);
    for (int lane : lanes)
        cerr << lane << " ";
    cerr << endl;
}
int binsearch_avx2(int x) { | |
int l = 0, r = n-1; | |
reg x_vec = _mm256_set1_epi32(x); | |
const int range[8] = {1, 2, 3, 4, 5, 6, 7, 8}; | |
const reg range_vec = _mm256_loadu_si256((reg*) range); | |
//cerr << "val: " << x << endl; | |
while (r - l > 50) { | |
//cerr << l << " " << r << endl; | |
int step = (r - l + 1) / 9; | |
reg step_vec = _mm256_set1_epi32(step); | |
//print(step_vec); | |
reg t = _mm256_mullo_epi32(range_vec, step_vec); | |
reg y = _mm256_i32gather_epi32(a + l, t, 4); | |
reg c = _mm256_cmpgt_epi32(y, x_vec); | |
//print(t); | |
//print(y); | |
int mask = _mm256_movemask_epi8(c); | |
//cerr << bitset<32>(mask) << " "; | |
int lsb = (ffs(mask)+3)/4; // this may be slow as shit | |
//cerr << lsb << endl; | |
l = (lsb == 0 ? l : l + (lsb-1) * step); | |
r = (lsb == 8 ? r : l + step); | |
} | |
while (a[l] < x) | |
l++; | |
//cerr << "ans: " << a[l] << endl; | |
//cerr << endl; | |
return a[l]; | |
} | |
// Implicit search tree filled by preproc() and read by binsearch_oblivious:
// node k lives at b[k] with the root at b[1]. For n == 1 << 22 the largest
// index preproc() writes is n (the all-left path), hence n + 1 slots.
int b[n + 1] = {};
// Reverses the low log2(n) bits of x (n is a power of two here, 1 << 22, so
// 22 bits): bit l of x becomes bit (21 - l) of the result, e.g.
// revbits(1) == 1 << 21.
//
// Fix: the loop combined bits with '&' and never masked (x >> l) down to one
// bit, so the result was always 0. Each step now shifts the accumulator left
// and appends bit l of x.
int revbits(int x) {
    int y = 0;
    for (int l = 0; (1 << l) < n; l++)
        y = (y << 1) | ((x >> l) & 1);
    return y;
}
void preproc_power_of_two() { | |
for (int i = 0; i < n; i++) | |
b[revbits(i)] = a[i]; | |
} | |
void preproc(int l = 0, int r = n, int k = 1) { | |
if (l < r) { | |
int t = (l + r) / 2; | |
b[k] = a[t]; | |
preproc(l, t, 2 * k); | |
preproc(t + 1, r, 2 * k + 1); | |
} | |
} | |
int binsearch_oblivious(int x) { | |
int l = 0, r = n, k = 1; | |
while (l < r) { | |
int t = (l + r) / 2; | |
//cerr << t << " " << k << endl; | |
if (b[k] >= x) | |
r = t, k = 2 * k; | |
else | |
l = t + 1, k = 2 * k + 1; // this shouldn't execute if this is last loop | |
//cerr << l << " " << r << " " << k << endl; | |
} | |
k /= 2; | |
return b[k]; | |
} | |
const int block_size = 16; | |
alignas(64) int c[n]; | |
void preproc_btree() { | |
int l = 0, r = n; | |
int k = 0; | |
while (l < r) { | |
int step = (r - l) / (block_size + 1); | |
if (step == 0) { | |
while (c[l] < x) | |
l++, k++; | |
break; | |
} | |
for (int i = 0; i < block_size; i++) { | |
if (c[k+i] >= x) { | |
r = l + i * step + 1; | |
if (i > 0) | |
l = l + (i - 1) * step; | |
k = k * (block_size + 1) + i; | |
break; | |
} | |
} | |
k = k * (block_size + 1) + block_size; | |
l = l + block_size * step; | |
} | |
return c[k]; | |
} | |
// Number of bits needed to represent x, i.e. the 1-based position of its
// highest set bit: msb(1) == 1, msb(8) == 4. Returns 0 for x == 0, for which
// __builtin_clz is undefined (the previous code had no such guard).
int msb(int x) {
    return x == 0 ? 0 : 32 - __builtin_clz(x);
}
// Lookup table with 15 entries; only the first two are given, and the remaining
// 13 are zero-initialized. It is not referenced anywhere in this file.
// Fix: the original initializer was missing the comma between 1 and 2 and did
// not compile. NOTE(review): the intended full contents are unknown.
const int redir[15] = {
    1,
    2,
};
// NOTE(review): work-in-progress tree iterator. It does not compile as written, and
// nothing in main() uses it. Problems visible in this block:
//   - `constexpr int block_size` is a non-static data member; constexpr members must be static.
//   - the constructor is spelled `Ctree` (the class is `CTree`) and, like every member,
//     is private (class default access).
//   - undeclared names / typos: `simufull_size`, `block_heigh`, `base-height`, `d`, `height(...)`.
//   - size() is missing the ';' after its return expression.
//   - a second, different `template<int d = 4> class CTree` follows, redefining the name.
// Apparent design (from the code): nodes are grouped into blocks of
// block_size = 2^block_height - 1 entries; `position` is a 0-based heap index inside the
// current block (children 2p+1 / 2p+2), and `idx` is that block's offset into `memory`.
template<typename T, int block_height = 4> | |
class CTree { | |
constexpr int block_size = (1<<block_height) - 1; | |
int tree_height, current_height = 0; | |
int position = 0, idx = 0, h_idx = 0, special, special_size; | |
T* memory; | |
Ctree(T* _memory, int array_length) { | |
memory = _memory; | |
// tree_height = number of significant bits in array_length (undefined for array_length == 0).
tree_height = 32 - __builtin_clz(array_length); | |
// NOTE(review): no-op expression statement.
tree_height; | |
int semifull_size = (1<<(tree_height-1)) - 1; | |
int remaining = array_length - simufull_size; | |
int critical_height = (tree_height - 1) / block_height * block_heigh; | |
int last_layer_block_size = (1<<(tree_height-critical_height-1)); | |
special = remaining / last_layer_block_size; | |
special_size = remaining % last_layer_block_size; | |
// should be last that has full | |
} | |
// Node count of a perfect binary tree with d levels: 2^d - 1.
int size(int d) { | |
return (1<<d) - 1 | |
} | |
// Appears to compute the offset in memory of block h_idx at level current_height;
// below the "high" levels, blocks before/after `special` are assumed to differ in size.
int get_idx() { | |
int base = size(current_height); | |
if (high()) | |
return base + block_size * h_idx; | |
int base_height = size(tree_height - current_height - 1); | |
if (h_idx <= special) | |
return base + h_idx * 2*base_height; | |
else | |
return base | |
+ h_idx * 2*base-height | |
+ (h_idx - special - 1) * base_height | |
+ special_size; | |
} | |
// Switches to a child block once position has left the current block.
// NOTE(review): hard-codes 4 (levels per block) and 15 (block_size) rather than using
// block_height / block_size, and `d` is not declared in this class.
void try_new_block() { | |
if (position >= block_size) { | |
h_idx = h_idx * (1<<d) + (position - 15); | |
current_height += 4; | |
idx = get_idx(); | |
position = 0; | |
} | |
} | |
// True while current_height + 4 < tree_height (hard-coded 4 again).
bool high() { | |
return current_height + 4 < tree_height; | |
} | |
// Number of nodes in the current block; below the "high" levels the size depends
// on whether h_idx is before, at, or after `special`.
int get_size() { | |
if (high()) | |
return block_size; | |
if (h_idx < special) | |
return height(tree_height - current_height); | |
if (h_idx > special) | |
return height(tree_height - current_height - 1); | |
return special_size; | |
} | |
// Descend to the left / right child (0-based heap indexing), changing block if needed.
void left() { | |
position = 2*position + 1; | |
try_new_block(); | |
} | |
void right() { | |
position = 2*position + 2; | |
try_new_block(); | |
} | |
// NOTE(review): the left child is 2*position + 1, so "no left child in this block" would be
// 2*position + 1 >= get_size(); confirm which test is intended.
bool end() { | |
return 2*position >= get_size(); | |
} | |
// Key stored at the current node.
T data() { | |
return memory[idx + position]; | |
} | |
}; | |
// NOTE(review): second, abandoned draft of CTree. It redefines the class template declared
// just above with a different parameter list and is largely pseudo-code, so it does not
// compile. Visible issues:
//   - `bsize` is a non-static constexpr member (must be static).
//   - undeclared names: `height` (in s/get_idx), `w`, `special`, `arr`; the constructor
//     calls msb(n) on member n before anything sets it.
//   - pseudo-code lines: "get id of special guy", "if high {", "if all full", etc.
//   - change_block() contains a nested function definition and an empty `if ()` condition.
//   - end() calls s() with no arguments, but s takes (h, l).
template<int d = 4> | |
class CTree { | |
constexpr int bsize = (1<<d) - 1; | |
int n, l, h = 0, k = 0, p = 0; | |
CTree() { | |
// 1, 3, | |
int height = (1<<msb(n)); | |
int rem = (1<<(height-1)) - n; | |
get id of special guy | |
} | |
// l - index of tree horizontally (starting from zero) | |
// Appears to return the node count of block l at height h: bsize for full blocks,
// otherwise a partially filled last level.
int s(int h, int l) { | |
if (height - h >= 4) | |
return bsize; | |
int sz = (1<<w) - 1; | |
int rem = n - (1<<w); | |
int step = (1<<w); | |
rem -= (l * step); | |
sz += min((1<<w), max(rem, 0)); | |
return sz; | |
} | |
int get_idx(int h, int l) { | |
// get full | |
if high { | |
// full_layers | |
} | |
if (height - h < 4) { | |
int num_full = // full layers | |
const int sz_full = (1<<w) - 1; | |
cosnt int sz_short (1<<(w-1)) - 1; | |
if (l <= special) { | |
int w = num_full * sz_full; | |
+ (l - num_full) * sz_short; | |
return k + w; | |
} | |
else { | |
// short * l + rem; | |
} | |
} | |
} | |
int change_block(int dir) { | |
int new_l = l * (1<<d) + dir; | |
int new_h = h + d; | |
int new_p = 0; | |
left(p) | |
int get_global_num() { | |
} | |
if (/* last layer */) { | |
if all full | |
if all empty | |
if intersect { | |
} | |
} | |
} | |
// 0-based heap children. NOTE(review): left() takes p as an argument while right()
// reads the member p.
int left(int p) { | |
return 2*p + 1; | |
} | |
int right() { | |
return 2*p + 2; | |
} | |
// NOTE(review): as written this is true when node p's left child lies inside the block,
// i.e. "not at the end", the opposite of what the name and ultrasearch's
// `while (!t.end())` suggest. Confirm the intended sense.
bool end() { | |
return left(p) < s(); | |
} | |
// Key at the current node; `arr` is not declared in this class.
int data() { // dereference | |
return arr[k+p]; | |
} | |
// Level-order (1-based) numbering of one 4-level, 15-node block, i.e. bsize for d = 4.
/* | |
1 | |
2 3 | |
4 5 6 7 | |
8 9 10 11 12 13 14 15 | |
*/ | |
}; | |
// NOTE(review): does not compile. `tree` is not declared anywhere in this file (the class
// template drafts above are named CTree), and this int function returns a constructed
// tree object. Not called from main().
int build() { | |
return tree<int>(a, n); | |
} | |
// NOTE(review): draft search over a CTree-style iterator; it does not compile. `tree` is not
// declared, `*t` needs an operator* that neither CTree draft defines, `t.right()` lacks
// its ';', and this int function returns nothing. Not called from main().
// NOTE(review): keys < x are sent to the LEFT child here, whereas binsearch_oblivious moves
// right (2k + 1) when b[k] < x. Confirm the intended ordering of the tree.
int ultrasearch(int x) { | |
auto t = tree.root(); | |
while (!t.end()) { | |
if (*t < x) | |
t.left(); | |
else | |
t.right() | |
} | |
} | |
// NOTE(review): unfinished stubs that do not compile. `int preproc` has no terminating ';'
// (so it runs into the next declaration), and a variable named preproc would conflict with
// the preproc() function defined earlier in this file. ssearch() is declared to return int
// but has an empty body (undefined behavior if it were ever called). Neither is used by main().
int preproc | |
int ssearch() { | |
} | |
int binsearch_btree(int x) { | |
int l = 0, r = n; | |
int k = 0; | |
while (l < r) { | |
if ( | |
if (step == 0) { | |
while (c[l] < x) | |
l++, k++; | |
break; | |
} | |
int step = (r - l) / (block_size + 1); | |
for (int i = 0; i < block_size; i++) { | |
if (c[k+i] >= x) { | |
r = l + i * step + 1; | |
if (i > 0) | |
l = l + (i - 1) * step; | |
k = k * (block_size + 1) + i; | |
break; | |
} | |
} | |
k = k * (block_size + 1) + block_size; | |
l = l + block_size * step; | |
} | |
return c[k]; | |
} | |
void timeit(int (*f)(int x), string name) { | |
cerr << "> " << name << endl; | |
clock_t start = clock(); | |
for (int i = 0; i < m; i++) | |
results[i] = f(q[i]); | |
cout << "time: " << double(clock() - start) / CLOCKS_PER_SEC << endl; | |
for (int i = 0; i < 8; i++) | |
cout << results[i] << " "; | |
cout << endl; | |
} | |
int main() { | |
for (int i = 0; i < n; i++) | |
a[i] = rand(); | |
for (int i = 0; i < m; i++) | |
q[i] = rand(); | |
a[0] = RAND_MAX; | |
sort(a, a+n); | |
timeit(binsearch_stl, "stl"); | |
timeit(binsearch_manual, "manual"); | |
//timeit(binsearch_prefetch, "prefetch"); | |
//timeit(binsearch_hybrid); | |
//timeit(binsearch_avx2, "avx2"); | |
preproc(); | |
timeit(binsearch_oblivious, "oblivious"); | |
preproc_btree(); | |
timeit(binsearch_btree, "btree"); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment