Skip to content

Instantly share code, notes, and snippets.

@louchenyao
Last active November 19, 2020 06:08
Show Gist options
  • Save louchenyao/c7f31255608b47c0687f3f82ec36ccec to your computer and use it in GitHub Desktop.
Save louchenyao/c7f31255608b47c0687f3f82ec36ccec to your computer and use it in GitHub Desktop.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
// Compare-exchange primitive of the sorting network.
// After the call `a` holds the smaller and `b` the larger of the two values,
// as ordered by `operator>`. Nothing is touched when `a > b` is false.
template <typename T>
void cmp(T &a, T &b) {
    if (!(a > b)) {
        return;  // already in order (or equal)
    }
    std::swap(a, b);
}
// Returns the smallest power of two that is >= ceil(v / 2).
//
// For v >= 2 this is the largest power of two strictly smaller than v
// (2 -> 1, 3 -> 2, 5 -> 4, 8 -> 4, 9 -> 8). v == 1 yields 1.
// v == 0 yields 0: the decrement below wraps to all ones and the final
// increment wraps back to zero.
constexpr uint32_t div2_roundup_to_power_of_two(uint32_t v) {
    // ceil(v / 2), computed without forming v + 1 so that v == UINT32_MAX
    // does not wrap around to 0.
    v = v / 2 + (v & 1u);
    // Round up to a power of two: copy the highest set bit of (v - 1) into
    // every lower position, then add one.
    v--;
    v |= v >> 1;
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;
    v++;
    return v;
}
// Merge stage of the sorting network over v[BEGIN, END).
//
// A range of at most one element is left untouched. Otherwise one layer of
// compare-exchanges is applied between elements that lie a power-of-two
// distance apart, and the two resulting sub-ranges are merged recursively.
//
// REV == false: the smaller element of every pair ends up at the lower
//               index; the range is split as [short part | power-of-two part].
// REV == true : the smaller element of every pair ends up at the higher
//               index; the range is split as [power-of-two part | short part].
template <typename T, int BEGIN, int END, bool REV>
struct Merge {
    static void merge(T *v) {
        const int LEN = END - BEGIN;
        if (LEN <= 1) return;

        // Distance between compared elements, and size of the larger part.
        const int BIG = div2_roundup_to_power_of_two(LEN);
        // Number of compare-exchanges in this layer, and size of the other part.
        const int SMALL = LEN - BIG;
        // Index at which the range is cut for the recursive merges.
        const int SPLIT = REV ? BEGIN + BIG : BEGIN + SMALL;

        if (REV) {
            #pragma unroll
            for (int lo = BEGIN; lo < BEGIN + SMALL; ++lo) {
                cmp(v[lo + BIG], v[lo]);
            }
        } else {
            #pragma unroll
            for (int hi = END - 1; hi >= END - SMALL; --hi) {
                cmp(v[hi - BIG], v[hi]);
            }
        }

        Merge<T, BEGIN, SPLIT, REV>::merge(v);
        Merge<T, SPLIT, END, REV>::merge(v);
    }
};
// Compile-time sorting network (a bitonic sorter) over v[BEGIN, END).
//
// The first half of the range is sorted with REV == false and the second
// half with REV == true; Merge<T, BEGIN, END, REV> then combines the two
// halves. Given the direction of the compare-exchanges in Merge, REV == false
// is meant to produce ascending order and REV == true descending order.
// Ranges of at most one element are returned unchanged.
template <typename T, int BEGIN, int END, bool REV>
struct ThreadSort {
    static void sort(T *v) {
        const int N = END - BEGIN;
        if (N <= 1) return;
        // Size of the first half; the second half takes the remainder.
        const int LEFT = N / 2;
        ThreadSort<T, BEGIN, BEGIN + LEFT, false>::sort(v);
        ThreadSort<T, BEGIN + LEFT, END, true>::sort(v);
        Merge<T, BEGIN, END, REV>::merge(v);
    }
};
// Exhaustively checks the N-element network on every 0/1 input.
// Bit i of CASE becomes a[i]; after sorting with REV == false the array must
// be non-decreasing. On the first failure the case is printed and the
// process exits with status 1.
template <int N>
void test() {
    const int TOTAL = 1 << N;
    for (int CASE = 0; CASE < TOTAL; ++CASE) {
        int a[N];
        for (int bit = 0; bit < N; ++bit) {
            a[bit] = (CASE >> bit) & 1;
        }

        ThreadSort<int, 0, N, false>::sort(a);

        if (!std::is_sorted(a, a + N)) {
            printf("N = %d, CASE = %d, ERROR\n", N, CASE);
            exit(1);
        }
    }
}
int main() {
// int v[5] = {1, 7, 3, 2, 4};
// ThreadSort<int, 0, 5, false>::sort(v);
// for (int i = 0; i < 5; i++) {
// printf("%d ", v[i]);
// }
// printf("\n");
test<4>();
test<5>();
test<6>();
test<7>();
test<8>();
test<9>();
test<10>();
test<11>();
test<12>();
test<13>();
test<14>();
test<15>();
test<16>();
printf("PASS\n");
return 0;
}
#include <algorithm>
#include <cstdio>
#include <cstdlib>
// Key/value compare-exchange primitive of the sorting network.
// When k[i] > k[j] the keys at positions i and j are swapped and the entries
// of v at the same two positions are swapped with them, so every value stays
// paired with its key. Otherwise both arrays are left untouched.
template <typename K, typename V, int N>
void cmp(K (&k)[N], V (&v)[N], int i, int j) {
    if (!(k[i] > k[j])) {
        return;  // keys already in order (or equal)
    }
    std::swap(k[i], k[j]);
    std::swap(v[i], v[j]);
}
// Returns the smallest power of two that is >= ceil(v / 2).
//
// For v >= 2 this is the largest power of two strictly smaller than v
// (2 -> 1, 3 -> 2, 5 -> 4, 8 -> 4, 9 -> 8). v == 1 yields 1.
// v == 0 yields 0: the decrement below wraps to all ones and the final
// increment wraps back to zero.
constexpr uint32_t div2_roundup_to_power_of_two(uint32_t v) {
    // ceil(v / 2), computed without forming v + 1 so that v == UINT32_MAX
    // does not wrap around to 0.
    v = v / 2 + (v & 1u);
    // Round up to a power of two: copy the highest set bit of (v - 1) into
    // every lower position, then add one.
    v--;
    v |= v >> 1;
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;
    v++;
    return v;
}
// Merge stage of the key/value sorting network over positions [BEGIN, END)
// of the parallel arrays k (keys) and v (values).
//
// A range of at most one element is left untouched. Otherwise one layer of
// key/value compare-exchanges is applied between positions that lie a
// power-of-two distance apart, and the two resulting sub-ranges are merged
// recursively.
//
// REV == false: the smaller key of every pair ends up at the lower index;
//               the range is split as [short part | power-of-two part].
// REV == true : the smaller key of every pair ends up at the higher index;
//               the range is split as [power-of-two part | short part].
template <typename K, typename V, int N, int BEGIN, int END, bool REV>
struct Merge {
    static void merge(K (&k)[N], V (&v)[N]) {
        const int LEN = END - BEGIN;
        if (LEN <= 1) return;

        // Distance between compared positions, and size of the larger part.
        const int BIG = div2_roundup_to_power_of_two(LEN);
        // Number of compare-exchanges in this layer, and size of the other part.
        const int SMALL = LEN - BIG;
        // Index at which the range is cut for the recursive merges.
        const int SPLIT = REV ? BEGIN + BIG : BEGIN + SMALL;

        if (REV) {
            #pragma unroll
            for (int lo = BEGIN; lo < BEGIN + SMALL; ++lo) {
                cmp(k, v, lo + BIG, lo);
            }
        } else {
            #pragma unroll
            for (int hi = END - 1; hi >= END - SMALL; --hi) {
                cmp(k, v, hi - BIG, hi);
            }
        }

        Merge<K, V, N, BEGIN, SPLIT, REV>::merge(k, v);
        Merge<K, V, N, SPLIT, END, REV>::merge(k, v);
    }
};
// Compile-time key/value sorting network (a bitonic sorter) over positions
// [BEGIN, END) of the parallel arrays k (keys) and v (values).
//
// The first half of the range is sorted with REV == false and the second
// half with REV == true; Merge<K, V, N, BEGIN, END, REV> then combines the
// two halves. Given the direction of the compare-exchanges in Merge,
// REV == false is meant to produce ascending keys and REV == true descending
// keys. Ranges of at most one element are returned unchanged.
template <typename K, typename V, int N, int BEGIN, int END, bool REV>
struct ThreadSort {
    static void sort(K (&k)[N], V (&v)[N]) {
        const int LEN = END - BEGIN;
        if (LEN <= 1) return;
        // Size of the first half; the second half takes the remainder.
        const int LEFT = LEN / 2;
        ThreadSort<K, V, N, BEGIN, BEGIN + LEFT, false>::sort(k, v);
        ThreadSort<K, V, N, BEGIN + LEFT, END, true>::sort(k, v);
        Merge<K, V, N, BEGIN, END, REV>::merge(k, v);
    }
};
// Exhaustively checks the N-element key/value network on every 0/1 key input.
//
// Bit i of CASE becomes k[i], and v[i] is set to i so that each value records
// the original position of its key. After sorting with REV == false:
//   * the keys must be non-decreasing, and
//   * every value must still sit next to the key it started with, i.e.
//     k[i] must equal bit v[i] of CASE.
// On the first failure the case is printed and the process exits with
// status 1.
template <int N>
void test() {
    for (int CASE = 0; CASE < (1 << N); CASE++) {
        int k[N];
        int v[N];
        for (int i = 0; i < N; ++i) {
            k[i] = (CASE >> i) & 1;
            v[i] = i;
        }
        ThreadSort<int, int, N, 0, N, false>::sort(k, v);
        for (int i = 1; i < N; ++i) {
            if (k[i] < k[i - 1]) {
                printf("N = %d, CASE = %d, ERROR\n", N, CASE);
                exit(1);
            }
        }
        for (int i = 0; i < N; ++i) {
            // Guard the shift below against a corrupted value.
            const bool in_range = v[i] >= 0 && v[i] < N;
            if (!in_range || k[i] != ((CASE >> v[i]) & 1)) {
                printf("N = %d, CASE = %d, VALUE ERROR\n", N, CASE);
                exit(1);
            }
        }
    }
}
int main() {
int k[5] = {1, 7, 3, 2, 4};
char v[5] = {'a', 'b', 'c', 'd', 'e'};
ThreadSort<int, char, 5, 0, 5, false>::sort(k, v);
for (int i = 0; i < 5; i++) {
printf("(%d %c) ", k[i], v[i]);
}
printf("\n");
test<4>();
test<5>();
test<6>();
test<7>();
test<8>();
test<9>();
test<10>();
test<11>();
test<12>();
test<13>();
test<14>();
test<15>();
test<16>();
printf("PASS\n");
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment