Skip to content

Instantly share code, notes, and snippets.

@nem0
Created February 10, 2021 21:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nem0/2009e821026de1e5f102bb0dc15203bc to your computer and use it in GitHub Desktop.
Save nem0/2009e821026de1e5f102bb0dc15203bc to your computer and use it in GitHub Desktop.
radix sort
// Stable LSD radix sort of `size` 64-bit keys; `_values[i]` travels with `_keys[i]`.
// Each of the 6 passes sorts by the digit (key >> shift) & Histogram::BIT_MASK and
// is split across WORKERS jobs:
//   1) every worker counts the digits of its contiguous slice into its own histogram,
//   2) a serial exclusive prefix sum (bucket-major, worker-minor) turns the counts
//      into per-worker scatter offsets -- inside one bucket worker j's elements land
//      before worker j+1's, matching input order, so the sort stays stable,
//   3) every worker scatters its slice into the scratch buffer.
// keys/tmp are swapped after each pass; after the 6th (even) swap the sorted data
// is back in _keys/_values.
void radixSort2(u64* _keys, u64* _values, int size) {
	enum { WORKERS = 10 };
	PROFILE_FUNCTION();
	profiler::pushInt("count", size);
	if (size <= 0) return;

	Array<u64>& tmp_mem = allocRadixTmp();
	// scratch layout: [0, size) keys, [size, 2*size) values
	tmp_mem.resize(size * 2);
	u64* keys = _keys;
	u64* values = _values;
	u64* tmp_keys = tmp_mem.begin();
	u64* tmp_values = &tmp_mem[size];

	const u32 count = (u32)size;
	// slice length per worker; trailing workers may end up with an empty slice
	const u32 step = (count + WORKERS - 1) / WORKERS;
	u16 shift = 0;
	for (int pass = 0; pass < 6; ++pass) {
		// per-worker histograms; reused as per-worker scatter offsets after step 2
		u32 histograms[WORKERS][Histogram::SIZE];

		jobs::forEach(WORKERS, 1, [&](i32 idx, i32) {
			PROFILE_BLOCK("histo2");
			// clamp `from` as well, so idle workers get the empty range [count, count)
			const u32 from = minimum(idx * step, count);
			const u32 to = minimum(from + step, count);
			u32* local_histogram = histograms[idx];
			memset(local_histogram, 0, sizeof(histograms[idx]));
			for (u32 i = from; i < to; ++i) {
				const u64 index = (keys[i] >> shift) & Histogram::BIT_MASK;
				++local_histogram[index];
			}
		});

		// exclusive prefix sum, bucket-major / worker-minor (see header for why)
		u32 offset = 0;
		for (u32 bucket = 0; bucket < lengthOf(histograms[0]); ++bucket) {
			for (u32 worker = 0; worker < lengthOf(histograms); ++worker) {
				const u32 n = histograms[worker][bucket];
				histograms[worker][bucket] = offset;
				offset += n;
			}
		}

		jobs::forEach(WORKERS, 1, [&](i32 idx, i32) {
			PROFILE_BLOCK("scatter multi");
			const u32 from = minimum(idx * step, count);
			const u32 to = minimum(from + step, count);
			u32* local_histogram = histograms[idx];
			// from <= to always holds now, so this never reports an underflowed count
			profiler::pushInt("count", to - from);
			for (u32 i = from; i < to; ++i) {
				const u64 key = keys[i];
				const u64 index = (key >> shift) & Histogram::BIT_MASK;
				const u32 dest = local_histogram[index]++;
				ASSERT(dest < count);
				tmp_keys[dest] = key;
				tmp_values[dest] = values[i];
			}
		});

		swap(tmp_keys, keys);
		swap(tmp_values, values);
		shift += Histogram::BITS;
	}
	releaseRadixTmp(tmp_mem);
}
// Stable LSD radix sort of `size` 64-bit keys; `_values[i]` travels with `_keys[i]`.
// 6 passes, each sorting by the digit (key >> shift) & Histogram::BIT_MASK with
// `shift` advancing by Histogram::BITS. Each pass's scatter is split between two
// threads:
//   - a job (back_pass) walks the upper half [size/2, size) backwards, filling
//     each bucket from its end,
//   - the calling thread walks the lower half [0, size/2) forwards, filling each
//     bucket from its start.
// Within a bucket the lower-half elements occupy the front and the upper-half
// elements the back, in input order, so the writers never overlap and the sort
// stays stable. After the 6th (even) swap the result is back in _keys/_values.
void radixSort3(u64* _keys, u64* _values, int size) {
	PROFILE_FUNCTION();
	profiler::pushInt("count", size);
	if (size <= 0) return;
	Array<u64>& tmp_mem = allocRadixTmp();
	u64* keys = _keys;
	u64* values = _values;
	u64* tmp_keys = nullptr;
	u64* tmp_values = nullptr;
	Histogram histogram;
	u16 shift = 0;
	for (int pass = 0; pass < 6; ++pass) {
		// presumably fills histogram.m_histogram with per-bucket counts for this digit
		// (the code below treats the entries as counts)
		histogram.compute(keys, values, size, shift);
		// if (histogram.m_sorted) {
		// if (pass & 1) {
		// memcpy(_keys, tmp_mem.begin(), tmp_mem.byte_size() / 2);
		// memcpy(_values, &tmp_mem[size], tmp_mem.byte_size() / 2);
		// }
		// return;
		//}
		if (!tmp_keys) {
			// scratch sized lazily on the first pass: [0, size) keys, [size, 2*size) values
			tmp_mem.resize(size * 2);
			tmp_keys = tmp_mem.begin();
			tmp_values = &tmp_mem[size];
		}

		// Turn counts into bucket start offsets (exclusive prefix sum) and record each
		// bucket's one-past-the-end offset for the backward scatter.
		u32 bucket_ends[Histogram::SIZE];
		u32 offset = 0;
		for (int i = 0; i < Histogram::SIZE; ++i) {
			const u32 count = histogram.m_histogram[i];
			histogram.m_histogram[i] = offset;
			offset += count;
			bucket_ends[i] = offset;
		}
		// every key must have been counted, otherwise the two scatters would not meet
		ASSERT(offset == (u32)size);

		auto back_pass = [&]() {
			PROFILE_BLOCK("back_pass");
			profiler::pushInt("pass", pass);
			profiler::pushInt("count", size - size / 2);
			u64* LUMIX_RESTRICT k = keys;
			u64* LUMIX_RESTRICT v = values;
			u64* LUMIX_RESTRICT tk = tmp_keys;
			u64* LUMIX_RESTRICT tv = tmp_values;
			u32* LUMIX_RESTRICT h = bucket_ends;
			for (int i = size - 1; i >= size / 2; --i) {
				const u64 key = k[i];
				const u16 index = (key >> shift) & Histogram::BIT_MASK;
				const u32 dest = --h[index];
				tk[dest] = key;
				tv[dest] = v[i];
			}
		};
		jobs::SignalHandle signal = jobs::INVALID_HANDLE;
		// back_pass and bucket_ends live in this stack frame; the jobs::wait(signal)
		// below is presumably what keeps them alive until the job has finished.
		jobs::run(
			&back_pass,
			[](void* data) {
				auto f = (decltype(back_pass)*)data;
				(*f)();
			},
			&signal);
		for (int i = 0; i < size / 2; ++i) {
			const u64 key = keys[i];
			const u16 index = (key >> shift) & Histogram::BIT_MASK;
			const u32 dest = histogram.m_histogram[index]++;
			tmp_keys[dest] = key;
			tmp_values[dest] = values[i];
		}
		jobs::wait(signal);
		swap(tmp_keys, keys);
		swap(tmp_values, values);
		shift += Histogram::BITS;
	}
	releaseRadixTmp(tmp_mem);
}
// Single-threaded stable LSD radix sort of `size` 64-bit keys; `_values[i]` moves
// together with `_keys[i]`. Runs 6 passes; each pass buckets by the digit
// (key >> shift) & Histogram::BIT_MASK, turns the counts into start offsets and
// scatters every element into the scratch buffer in input order. Source and
// destination are swapped after every pass, so after the 6th (even) swap the
// sorted data sits in _keys/_values again.
void radixSort(u64* _keys, u64* _values, int size) {
	PROFILE_FUNCTION();
	profiler::pushInt("count", size);
	if (size == 0) return;

	Array<u64>& scratch = allocRadixTmp();
	u64* src_keys = _keys;
	u64* src_values = _values;
	u64* dst_keys = nullptr;
	u64* dst_values = nullptr;
	Histogram histogram;
	u16 shift = 0;

	// Moves src[begin, end) to dst, advancing each bucket's running offset.
	auto scatter = [&](int begin, int end) {
		for (int i = begin; i < end; ++i) {
			const u64 key = src_keys[i];
			const u16 bucket = (key >> shift) & Histogram::BIT_MASK;
			const u32 dest = histogram.m_histogram[bucket]++;
			dst_keys[dest] = key;
			dst_values[dest] = src_values[i];
		}
	};

	for (int pass = 0; pass < 6; ++pass) {
		histogram.compute(src_keys, src_values, size, shift);
		PROFILE_BLOCK("radix sort pass");
		profiler::pushInt("count", size);
		// if (histogram.m_sorted) {
		// if (pass & 1) {
		// memcpy(_keys, scratch.begin(), scratch.byte_size() / 2);
		// memcpy(_values, &scratch[size], scratch.byte_size() / 2);
		// }
		// return;
		//}
		if (dst_keys == nullptr) {
			// first pass only: scratch holds keys in [0, size), values in [size, 2*size)
			scratch.resize(size * 2);
			dst_keys = scratch.begin();
			dst_values = &scratch[size];
		}

		// counts -> bucket start offsets (exclusive prefix sum)
		u32 running = 0;
		for (int bucket = 0; bucket < Histogram::SIZE; ++bucket) {
			const u32 n = histogram.m_histogram[bucket];
			histogram.m_histogram[bucket] = running;
			running += n;
		}

		const int half = size / 2;
		{
			// the first half gets its own profiler scope
			PROFILE_BLOCK("1sthalf");
			profiler::pushInt("count", half);
			scatter(0, half);
		}
		scatter(half, size);

		swap(dst_keys, src_keys);
		swap(dst_values, src_values);
		shift += Histogram::BITS;
	}
	releaseRadixTmp(scratch);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment