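//Compare three CUDA strategies for a grouped sum: given labels[i] in
//[0, num_classes) and weights[i], compute the per-class total weight and
//per-class element count.
//  1. group_summer_shmem:    per-thread shared-memory stripes + cub::BlockReduce
//  2. group_summer_shatomic: one shared-memory copy per block, shared atomics
//  3. group_summer_global:   global-memory atomics only
//A single-threaded CPU implementation serves as the correctness reference.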
#include <cub/cub.cuh>
#include <thrust/device_vector.h>

#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

//Per-class results: sums[l] and counts[l] for each label l
struct SumCountRet {
  std::vector<double>   sums;
  std::vector<uint32_t> counts;
};
//Each thread accumulates into its own private stripe of shared memory
//(num_classes entries per thread), then the block reduces each class with
//cub::BlockReduce and thread 0 atomically adds the block totals to global
//memory. Shared memory use grows as num_threads*num_classes, which caps how
//many classes this kernel can handle.
__global__ void group_summer_shmem(
  const int32_t *const labels,
  const float   *const weights,
  const int num_elements,
  const int num_classes,
  double   *const sums,
  uint32_t *const counts
){
  constexpr int num_threads = 128;
  assert(num_threads==blockDim.x);

  //Get shared memory: all the doubles first, then the uint32_t counts
  extern __shared__ int s[];
  double   *const sums_shmem   = (double*)s;
  uint32_t *const counts_shmem = (uint32_t*)&sums_shmem[num_threads*num_classes];
  double   *const my_sums      = &sums_shmem  [num_classes*threadIdx.x];
  uint32_t *const my_counts    = &counts_shmem[num_classes*threadIdx.x];

  //Zero this block's shared memory
  for(int i=threadIdx.x;i<num_threads*num_classes;i+=num_threads){
    sums_shmem[i]   = 0;
    counts_shmem[i] = 0;
  }
  __syncthreads();

  //Grid-stride loop: each thread accumulates into its private stripe,
  //so no atomics are needed here
  for(int i=blockIdx.x*blockDim.x+threadIdx.x;i<num_elements;i+=gridDim.x*blockDim.x){
    const auto l = labels[i];
    my_sums[l]   += weights[i];
    my_counts[l]++;
  }
  __syncthreads();

  //Reduce each class across the block, then thread 0 flushes the block
  //totals to global memory
  __shared__ cub::BlockReduce<double,   num_threads>::TempStorage double_temp_storage;
  __shared__ cub::BlockReduce<uint32_t, num_threads>::TempStorage uint32_t_temp_storage;
  for(int l=0;l<num_classes;l++){
    const auto sums_total   = cub::BlockReduce<double,  num_threads>(double_temp_storage  ).Reduce(my_sums[l],   cub::Sum());
    const auto counts_total = cub::BlockReduce<uint32_t,num_threads>(uint32_t_temp_storage).Reduce(my_counts[l], cub::Sum());
    if(threadIdx.x==0){
      atomicAdd(&sums[l],   sums_total);
      atomicAdd(&counts[l], counts_total);
    }
    //CUB requires a barrier before temp storage is reused on the next iteration
    __syncthreads();
  }
}
//One shared copy of the class totals per block: threads accumulate with
//shared-memory atomics, then the block flushes its totals with one global
//atomic per class. Shared memory use grows only as num_classes.
__global__ void group_summer_shatomic(
  const int32_t *const labels,
  const float   *const weights,
  const int num_elements,
  const int num_classes,
  double   *const sums,
  uint32_t *const counts
){
  constexpr int num_threads = 128;
  assert(num_threads==blockDim.x);

  //Get shared memory: all the doubles first, then the uint32_t counts
  extern __shared__ int s[];
  double   *const sums_shmem   = (double*)s;
  uint32_t *const counts_shmem = (uint32_t*)&sums_shmem[num_classes];

  //Zero this block's shared memory
  for(int i=threadIdx.x;i<num_classes;i+=num_threads){
    sums_shmem[i]   = 0;
    counts_shmem[i] = 0;
  }
  __syncthreads();

  //Grid-stride loop: accumulate with shared-memory atomics
  for(int i=blockIdx.x*blockDim.x+threadIdx.x;i<num_elements;i+=gridDim.x*blockDim.x){
    const auto l = labels[i];
    atomicAdd(&sums_shmem[l], (double)weights[i]);
    atomicAdd(&counts_shmem[l], 1);
  }
  __syncthreads();

  //Flush the block totals to global memory
  for(int i=threadIdx.x;i<num_classes;i+=num_threads){
    atomicAdd(&sums[i],   sums_shmem[i]);
    atomicAdd(&counts[i], counts_shmem[i]);
  }
}
//Baseline: every thread atomically adds directly to global memory
__global__ void group_summer_global(
  const int32_t *const labels,
  const float   *const weights,
  const int num_elements,
  const int num_classes,
  double   *const sums,
  uint32_t *const counts
){
  (void)num_classes; //unused; kept so all three kernels share a signature
  for(int i=blockIdx.x*blockDim.x+threadIdx.x;i<num_elements;i+=gridDim.x*blockDim.x){
    const auto l = labels[i];
    atomicAdd(&sums[l], (double)weights[i]);
    atomicAdd(&counts[l], 1);
  }
}
//Single-threaded CPU reference implementation
SumCountRet group_summer_cpu(
  const std::vector<int32_t> &labels,
  const std::vector<float>   &weights
){
  const int num_classes = 1 + *std::max_element(labels.begin(), labels.end());
  std::vector<double>   sums(num_classes);
  std::vector<uint32_t> counts(num_classes);
  for(size_t i=0;i<labels.size();i++){
    const auto l = labels[i];
    sums[l] += weights[i];
    counts[l]++;
  }
  return {sums, counts};
}
//Element-wise comparison with an absolute tolerance, for the floating-point
//sums, whose order of accumulation differs between CPU and GPU
template<class T>
bool vec_nearly_equal(const std::vector<T> &a, const std::vector<T> &b){
  if(a.size()!=b.size())
    return false;
  for(size_t i=0;i<a.size();i++){
    if(std::abs(a[i]-b[i])>1e-4)
      return false;
  }
  return true;
}
//Copies the inputs to the device, calls `func` to launch a kernel, and copies
//the results back
template<typename Func>
SumCountRet cuda_call(const std::vector<int32_t> &labels, const std::vector<float> &weights, Func func){
  const int num_classes = 1 + *std::max_element(labels.begin(), labels.end());

  thrust::device_vector<int32_t>  d_labels(labels.size());
  thrust::device_vector<float>    d_weights(labels.size());
  thrust::device_vector<double>   d_sums(num_classes);
  thrust::device_vector<uint32_t> d_counts(num_classes);
  thrust::copy(labels.begin(),  labels.end(),  d_labels.begin());
  thrust::copy(weights.begin(), weights.end(), d_weights.begin());

  func(d_labels, d_weights, d_sums, d_counts);

  std::vector<double>   h_sums(num_classes);
  std::vector<uint32_t> h_counts(num_classes);
  thrust::copy(d_sums.begin(),   d_sums.end(),   h_sums.begin());
  thrust::copy(d_counts.begin(), d_counts.end(), h_counts.begin());
  return {h_sums, h_counts};
}
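//Optional timing wrapper: a minimal sketch, not part of the original harness
//(the name `time_cuda_call` is illustrative). It times cuda_call end to end,
//so the measurement includes the host-device copies as well as the kernel;
//each launch lambda below ends with cudaDeviceSynchronize(), which makes the
//wall-clock span meaningful. It could be swapped in for the cuda_call calls
//in TestGroupSummer to compare the three kernels.
template<typename Func>
SumCountRet time_cuda_call(const char *name, const std::vector<int32_t> &labels, const std::vector<float> &weights, Func func){
  const auto start = std::chrono::steady_clock::now();
  auto ret = cuda_call(labels, weights, func);
  const auto stop  = std::chrono::steady_clock::now();
  std::cout<<name<<" took "
           <<std::chrono::duration_cast<std::chrono::milliseconds>(stop-start).count()
           <<" ms"<<std::endl;
  return ret;
}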
//Generates random data, runs all three kernels plus the CPU reference, and
//checks that the results agree
void TestGroupSummer(std::mt19937 &gen, const int N, const int label_max){
  std::vector<int32_t> labels(N);
  std::vector<float>   weights(N);

  std::uniform_int_distribution<int>    label_dist(0, label_max);
  std::uniform_real_distribution<float> weight_dist(0, 5000);
  for(int i=0;i<N;i++){
    labels[i]  = label_dist(gen);
    weights[i] = weight_dist(gen);
  }
  //Shared memory kernel
  const auto shmem_ret = cuda_call(labels, weights, [](
    thrust::device_vector<int32_t>  &d_labels,
    thrust::device_vector<float>    &d_weights,
    thrust::device_vector<double>   &d_sums,
    thrust::device_vector<uint32_t> &d_counts
  ){
    constexpr int num_threads = 128;
    const int num_blocks  = (d_labels.size() + num_threads - 1)/num_threads;
    const int num_classes = d_sums.size();
    //One stripe of sums and counts per thread
    const int shmem = num_threads * num_classes * (sizeof(double)+sizeof(uint32_t));
    group_summer_shmem<<<num_blocks,num_threads,shmem>>>(
      thrust::raw_pointer_cast(d_labels.data()),
      thrust::raw_pointer_cast(d_weights.data()),
      d_labels.size(),
      num_classes,
      thrust::raw_pointer_cast(d_sums.data()),
      thrust::raw_pointer_cast(d_counts.data())
    );
    const auto launch_err = cudaGetLastError();
    if(launch_err!=cudaSuccess){
      std::cout<<"Kernel failed to launch: "<<cudaGetErrorString(launch_err)<<std::endl;
    }
    const auto sync_err = cudaDeviceSynchronize();
    if(sync_err!=cudaSuccess){
      std::cout<<"Error in kernel: "<<cudaGetErrorString(sync_err)<<std::endl;
    }
  });
  //Shared atomic memory kernel
  const auto shatomic_ret = cuda_call(labels, weights, [](
    thrust::device_vector<int32_t>  &d_labels,
    thrust::device_vector<float>    &d_weights,
    thrust::device_vector<double>   &d_sums,
    thrust::device_vector<uint32_t> &d_counts
  ){
    constexpr int num_threads = 128;
    const int num_blocks  = (d_labels.size() + num_threads - 1)/num_threads;
    const int num_classes = d_sums.size();
    //One shared copy of the class totals per block
    const int shmem = num_classes * (sizeof(double)+sizeof(uint32_t));
    group_summer_shatomic<<<num_blocks,num_threads,shmem>>>(
      thrust::raw_pointer_cast(d_labels.data()),
      thrust::raw_pointer_cast(d_weights.data()),
      d_labels.size(),
      num_classes,
      thrust::raw_pointer_cast(d_sums.data()),
      thrust::raw_pointer_cast(d_counts.data())
    );
    const auto launch_err = cudaGetLastError();
    if(launch_err!=cudaSuccess){
      std::cout<<"Kernel failed to launch: "<<cudaGetErrorString(launch_err)<<std::endl;
    }
    const auto sync_err = cudaDeviceSynchronize();
    if(sync_err!=cudaSuccess){
      std::cout<<"Error in kernel: "<<cudaGetErrorString(sync_err)<<std::endl;
    }
  });
  //Global memory kernel
  const auto global_ret = cuda_call(labels, weights, [](
    thrust::device_vector<int32_t>  &d_labels,
    thrust::device_vector<float>    &d_weights,
    thrust::device_vector<double>   &d_sums,
    thrust::device_vector<uint32_t> &d_counts
  ){
    constexpr int num_threads = 128;
    const int num_blocks  = (d_labels.size() + num_threads - 1)/num_threads;
    const int num_classes = d_sums.size();
    const int shmem = 0; //no dynamic shared memory needed
    group_summer_global<<<num_blocks,num_threads,shmem>>>(
      thrust::raw_pointer_cast(d_labels.data()),
      thrust::raw_pointer_cast(d_weights.data()),
      d_labels.size(),
      num_classes,
      thrust::raw_pointer_cast(d_sums.data()),
      thrust::raw_pointer_cast(d_counts.data())
    );
    const auto launch_err = cudaGetLastError();
    if(launch_err!=cudaSuccess){
      std::cout<<"Kernel failed to launch: "<<cudaGetErrorString(launch_err)<<std::endl;
    }
    const auto sync_err = cudaDeviceSynchronize();
    if(sync_err!=cudaSuccess){
      std::cout<<"Error in kernel: "<<cudaGetErrorString(sync_err)<<std::endl;
    }
  });
  //Check every kernel against the CPU reference
  const auto correct_ret = group_summer_cpu(labels, weights);
  std::cout<<"shmem sums good?      "<<vec_nearly_equal(shmem_ret.sums,correct_ret.sums)   <<std::endl;
  std::cout<<"shmem counts good?    "<<(shmem_ret.counts==correct_ret.counts)              <<std::endl;
  std::cout<<"shatomic sums good?   "<<vec_nearly_equal(shatomic_ret.sums,correct_ret.sums)<<std::endl;
  std::cout<<"shatomic counts good? "<<(shatomic_ret.counts==correct_ret.counts)           <<std::endl;
  std::cout<<"global sums good?     "<<vec_nearly_equal(global_ret.sums,correct_ret.sums)  <<std::endl;
  std::cout<<"global counts good?   "<<(global_ret.counts==correct_ret.counts)             <<std::endl;
}
int main(){
  std::mt19937 gen; //default-seeded, so runs are reproducible
  TestGroupSummer(gen, 10000000, 10);
  TestGroupSummer(gen, 10000000, 10);
  TestGroupSummer(gen, 10000000, 10);
  TestGroupSummer(gen, 10000000, 10);
}
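//Build sketch (the filename and flags are assumptions, not from the original
//gist): atomicAdd on double requires compute capability 6.0 or higher, so
//something like
//  nvcc -O3 -std=c++14 -arch=sm_60 group_summer.cu -o group_summer
//should work; adjust -arch to match your GPU.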