Instantly share code, notes, and snippets.

Embed
What would you like to do?
#include <iostream>
#include <string>
#include <thrust/device_vector.h>
#include <thrust/reduce.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
// Print the first `size` elements of a device vector as "tag = {a, b, c}".
//
// Each val[i] access goes through device_vector::operator[], i.e. one
// device-to-host read per element; fine for the small debug arrays here.
//
// Parameters:
//   val  - device vector whose leading elements are printed.
//   tag  - label written before " = {".
//   size - number of leading elements to print. It is clamped to
//          [0, val.size()] so that an empty range prints "tag = {}" instead
//          of reading val[-1], and an over-long size never reads past the end.
void print_array(const thrust::device_vector<int>& val,
                 const std::string& tag,
                 const int size) {
  std::size_t count = 0;
  if (size > 0) {
    count = static_cast<std::size_t>(size);
    if (count > val.size()) {
      count = val.size();
    }
  }

  std::cout << tag << " = {";
  for (std::size_t i = 0; i < count; ++i) {
    // Separator goes before every element except the first, so no element
    // needs to be special-cased and the empty case needs no extra branch.
    if (i != 0) {
      std::cout << ", ";
    }
    std::cout << val[i];
  }
  std::cout << "}\n";
}
// Small Thrust demo:
//   1. reduce_by_key on a fixed key/value input (sums values over each run
//      of consecutive equal keys);
//   2. a histogram of the keys, built by sorting them and reducing a vector
//      of ones by key;
//   3. the histogram "pointer" array: pointer[0] = 0 and pointer[i + 1] is
//      the running total of the first i + 1 bin counts, so bin i occupies
//      [pointer[i], pointer[i + 1]) in the sorted keys.
int main() {
  const int key[] {1, 1, 2, 2, 2, 3, 3, 1, 1};
  const int value[] {1, 1, 1, 1, 1, 1, 1, 1, 1};
  const int in_size = sizeof(key) / sizeof(key[0]);

  // reduce_by_key test
  thrust::device_vector<int> key_in(key, key + in_size);
  thrust::device_vector<int> value_in(value, value + in_size);
  // Output buffers are sized for the worst case (every key distinct).
  thrust::device_vector<int> key_out(in_size, 0);
  thrust::device_vector<int> value_out(in_size, 0);
  auto new_end = thrust::reduce_by_key(key_in.begin(),
                                       key_in.end(),
                                       value_in.begin(),
                                       key_out.begin(),
                                       value_out.begin());
  // Number of reduced segments actually written.
  int out_size = static_cast<int>(new_end.second - value_out.begin());
  std::cout << "reduce_by_key test\n";
  std::cout << "input:\n";
  print_array(key_in, "key ", in_size);
  print_array(value_in, "value", in_size);
  std::cout << "output:\n";
  print_array(key_out, "key ", out_size);
  print_array(value_out, "value", out_size);
  std::cout << "\n";

  // histogram: sorting groups equal keys into contiguous runs, so reducing
  // a vector of ones by key yields the element count of each distinct key.
  thrust::device_vector<int> uni_vect(in_size, 1);
  thrust::device_vector<int> num_in_bins(in_size);
  thrust::sort(key_in.begin(), key_in.end());
  new_end = thrust::reduce_by_key(key_in.begin(),
                                  key_in.end(),
                                  uni_vect.begin(),
                                  key_out.begin(),
                                  num_in_bins.begin());
  out_size = static_cast<int>(new_end.second - num_in_bins.begin());
  std::cout << "histogram test\n";
  std::cout << "input:\n";
  print_array(key_in, "key", in_size);
  std::cout << "output:\n";
  print_array(key_out, "key ", out_size);
  print_array(num_in_bins, "number in bin", out_size);
  std::cout << "\n";

  // get histogram pointer
  thrust::device_vector<int> pointer(out_size + 1, 0);
  // Scan only the out_size valid bin counts. num_in_bins holds in_size
  // elements; scanning all of them would write in_size values starting at
  // pointer[1], overflowing the (out_size + 1)-element `pointer` buffer.
  thrust::inclusive_scan(num_in_bins.begin(),
                         num_in_bins.begin() + out_size,
                         pointer.begin() + 1);
  std::cout << "histogram pointer\n";
  print_array(pointer, "pointer", static_cast<int>(pointer.size()));
  std::cout << "\n";
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment