Skip to content

Instantly share code, notes, and snippets.

@peterwittek
Created August 22, 2013 05:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save peterwittek/6303575 to your computer and use it in GitHub Desktop.
Save peterwittek/6303575 to your computer and use it in GitHub Desktop.
Argmin on the rows of a matrix with Thrust
#undef _GLIBCXX_ATOMIC_BUILTINS
#undef _GLIBCXX_USE_INT128
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/reduce.h>
#include <thrust/functional.h>
#include <thrust/random.h>
#include <thrust/tuple.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <iostream>
// Row-major (C-style) flattening of a 2-D index: element (row, column) of a
// matrix with nColumns columns lives at offset row*nColumns + column of the
// underlying flat array.
int ci(int row, int column, int nColumns) {
    const int rowOffset = nColumns * row;
    return column + rowOffset;
}
// Functor mapping a linear (row-major) element index i to the index of the
// row that contains it, i.e. i / C for a matrix with C columns. Used as the
// key generator for reduce_by_key so that all elements of one row share a key.
// NOTE(review): thrust::unary_function is deprecated in newer Thrust releases;
// it is kept here for the result_type typedef that older versions rely on.
template <typename T>
struct linear_index_to_row_index : public thrust::unary_function<T,T>
{
    T C; // number of columns; must be non-zero whenever operator() is called

    __host__ __device__
    linear_index_to_row_index(T C) : C(C) {}

    // const so the functor can be invoked through a const object/reference,
    // which algorithms and fancy iterators are entitled to do.
    __host__ __device__
    T operator()(T i) const
    {
        return i / C;
    }
};
// (index, value) pair carried through the row-wise argmin reduction.
typedef thrust::tuple<int,float> argMinType;

// Binary reduction operator that keeps the pair with the smaller value.
//
// Ties are broken in favour of the smaller index. This turns the operator
// into a lexicographic minimum on (value, index), which is commutative as
// well as associative. The result therefore does not depend on how or in
// which order the reduction combines operands. For each row it is the
// first position of the minimum, matching std::min_element.
// (The previous version returned `b` on ties, so the reported index depended
// on operand order.)
//
// NOTE(review): NaN compares as "tied" with everything here, so a row that
// contains NaN can still produce an order-dependent result.
struct argMin : public thrust::binary_function<argMinType,argMinType,argMinType>
{
    __host__ __device__
    argMinType operator()(const argMinType& a, const argMinType& b) const
    {
        const float va = thrust::get<1>(a);
        const float vb = thrust::get<1>(b);
        if (va < vb) {
            return a;
        }
        if (vb < va) {
            return b;
        }
        // Equal (or unordered) values: prefer the earlier position.
        return (thrust::get<0>(a) <= thrust::get<0>(b)) ? a : b;
    }
};
// Row-wise argmin of a row-major nRows x nColumns matrix stored flat in A.
//
// Returns nRows (index, value) tuples, one per row. Each holds the smallest
// element of that row and its position. The index is the LINEAR index into
// A (row*nColumns + column), not the column number; use index % nColumns to
// recover the column.
//
// A is taken by const reference. The previous by-value parameter made a full
// device-to-device copy of the matrix on every call.
//
// Assumes A holds at least nRows*nColumns elements and that nColumns > 0
// whenever nRows > 0. NOTE(review): neither is checked here.
thrust::device_vector<argMinType> minsOfRowSpace(const thrust::device_vector<float>& A, int nRows, int nColumns) {
    // Output storage: one argmin tuple per row, plus the row keys that
    // reduce_by_key always emits (computed but not returned).
    thrust::device_vector<argMinType> row_argmins(nRows);
    thrust::device_vector<int> row_indices(nRows);

    // Key sequence: linear index k -> row k / nColumns. Every element of a
    // row carries the same key, so rows form contiguous runs of equal keys.
    typedef thrust::transform_iterator<linear_index_to_row_index<int>,
                                       thrust::counting_iterator<int> > RowKeyIterator;
    const RowKeyIterator keys_first(thrust::counting_iterator<int>(0),
                                    linear_index_to_row_index<int>(nColumns));
    const RowKeyIterator keys_last = keys_first + nRows * nColumns;

    // Values: (linear index, element) pairs, reduced per row with argMin.
    thrust::reduce_by_key(keys_first, keys_last,
                          thrust::make_zip_iterator(thrust::make_tuple(thrust::counting_iterator<int>(0), A.begin())),
                          row_indices.begin(),
                          row_argmins.begin(),
                          thrust::equal_to<int>(),
                          argMin());
    return row_argmins;
}
template <typename T>
void printMatrix(thrust::device_vector<T> A, int nRows, int nColumns) {
for (size_t i = 0; i < nRows; i++){
for (size_t j = 0; j < nColumns; j++){
std::cout << A[ci(i,j,nColumns)] << " ";
}
std::cout << "\n";
}
std::cout << "\n";
}
int main(void)
{
int R_A = 5; // number of rows of A
int C_A = 8; // number of columns of A
// initialize data
thrust::device_vector<float> A(R_A * C_A);
for (size_t i = 0; i < R_A; i++)
for (size_t j = 0; j < C_A; j++)
A[ci(i,j,C_A)]=i+j;
A[22]=-1;
printMatrix<float>(A, R_A, C_A);
thrust::device_vector<argMinType> minsOfA=minsOfRowSpace(A, R_A, C_A);
for (size_t i = 0; i < R_A; i++){
for (size_t j = 0; j < 1; j++){
argMinType tmp=minsOfA[ci(i,j,1)];
std::cout << thrust::get<0>(tmp) << " ";
}
std::cout << "\n";
}
std::cout << "\n";
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment