Last active
September 23, 2020 22:54
-
-
Save Jacajack/a4efc1ba6e46e82dc413ae14710633ac to your computer and use it in GitHub Desktop.
gotta go fast
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <chrono>
#include <cinttypes>
#include <cstddef>
#include <functional>
#include <iostream>
#include <limits>
#include <numeric>
#include <random>
#include <utility>
#include <vector>
/**
 * Sorts `size` ints starting at `data` in ascending order using an LSD
 * (least-significant-digit-first) radix sort with 8-bit digits.
 *
 * Each pass is a stable counting sort on one digit. Passes ping-pong between
 * `data` and a temporary buffer of the same length. The sorted result always
 * ends up back in `data`.
 *
 * Signed values are handled by sorting on an order-preserving unsigned key:
 * the sign bit is flipped when a digit is extracted. The caller's array never
 * holds transformed values.
 *
 * @param data Pointer to the first element; may be null when `size` is 0.
 * @param size Number of elements to sort.
 * @throws std::bad_alloc if the temporary buffer cannot be allocated; `data`
 *         is left unmodified in that case.
 */
void radix_sort(int *data, size_t size)
{
	constexpr unsigned key_bits = std::numeric_limits<unsigned int>::digits;
	constexpr unsigned digit_size = 8;
	constexpr size_t bucket_count = size_t{1} << digit_size;
	constexpr unsigned digit_mask = static_cast<unsigned>(bucket_count - 1);
	// Flipping the sign bit maps INT_MIN..INT_MAX monotonically onto 0..UINT_MAX,
	// so unsigned digit order matches signed value order.
	constexpr unsigned sign_bit = 1u << (key_bits - 1);
	static_assert(key_bits % digit_size == 0, "digit_size must evenly divide the width of int");
	constexpr unsigned pass_count = key_bits / digit_size;

	// Nothing to reorder; also avoids allocating for empty input
	if (size < 2)
		return;

	// Digit number `pass` (0 = least significant) of x's order-preserving unsigned key.
	// Converting int -> unsigned is well-defined (modular), unlike shifting a negative int.
	auto digit_of = [](int x, unsigned pass) -> unsigned {
		const unsigned key = static_cast<unsigned>(x) ^ sign_bit;
		return (key >> (pass * digit_size)) & digit_mask;
	};

	std::vector<int> buffer(size);
	int *input = data;
	int *output = buffer.data();
	size_t buckets[bucket_count];

	for (unsigned pass = 0; pass < pass_count; ++pass)
	{
		// Histogram of the current digit (size_t so counts cannot overflow)
		std::fill(buckets, buckets + bucket_count, size_t{0});
		for (size_t i = 0; i < size; ++i)
			++buckets[digit_of(input[i], pass)];

		// Inclusive prefix sums: buckets[d] becomes one past the last slot for digit d
		std::partial_sum(buckets, buckets + bucket_count, buckets);

		// Scatter back-to-front so elements with equal digits keep their order (stability)
		for (size_t i = size; i-- > 0;)
			output[--buckets[digit_of(input[i], pass)]] = input[i];

		std::swap(input, output);
	}

	// After an odd number of passes the sorted data sits in the temporary buffer
	if (input != data)
		std::copy(input, input + size, data);
}
int main(int argc, char *argv[]) | |
{ | |
int N = 50000000; | |
std::cout << "Sorting " << N << " elements!" << std::endl; | |
// Random generator | |
std::uniform_int_distribution<int> dist(std::numeric_limits<int>::min(), std::numeric_limits<int>::max()); | |
std::random_device rd; | |
std::mt19937 rng(rd()); | |
std::vector<int> data(N, 0); | |
for (auto &v : data) | |
v = dist(rng); | |
std::vector<int> radix_data = data; | |
std::vector<int> qsort_data = data; | |
auto t1 = std::chrono::system_clock::now(); | |
radix_sort(&radix_data[0], radix_data.size()); | |
auto t2 = std::chrono::system_clock::now(); | |
if (!std::is_sorted(radix_data.begin(), radix_data.end())) | |
std::cout << "Radix sort failed!" << std::endl; | |
else | |
std::cout << "Radix sort done!" << std::endl; | |
auto t3 = std::chrono::system_clock::now(); | |
std::sort(qsort_data.begin(), qsort_data.end()); | |
auto t4 = std::chrono::system_clock::now(); | |
if (!std::is_sorted(qsort_data.begin(), qsort_data.end())) | |
std::cout << "std::sort() failed!" << std::endl; | |
else | |
std::cout << "std::sort() done!" << std::endl; | |
auto radix_elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1); | |
auto qsort_elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(t4 - t3); | |
// Let's pretend | |
volatile int x; | |
for (int i : radix_data) x = i; | |
for (int i : qsort_data) x = i; | |
std::cout << "radix_sort: " << radix_elapsed.count() << "ms" << std::endl; | |
std::cout << " std::sort: " << qsort_elapsed.count() << "ms" << std::endl; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment