Created August 22, 2013 05:42
-
-
Save peterwittek/6303575 to your computer and use it in GitHub Desktop.
Argmin on the rows of a matrix with Thrust
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// NOTE(review): presumably a workaround for old nvcc/libstdc++ incompatibilities; keep before all includes.
#undef _GLIBCXX_ATOMIC_BUILTINS
#undef _GLIBCXX_USE_INT128
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/reduce.h>
#include <thrust/functional.h>
#include <thrust/random.h>
#include <thrust/tuple.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <cstddef>
#include <iostream>
#include <limits>
#include <stdexcept>
// Row-major ("C-style") offset of element (row, column) in a matrix that has
// nColumns columns, i.e. row * nColumns + column.
int ci(int row, int column, int nColumns) {
  const int rowStart = row * nColumns; // offset of the first element of `row`
  return column + rowStart;
}
// Convert a linear index to a row index | |
template <typename T> | |
struct linear_index_to_row_index : public thrust::unary_function<T,T> | |
{ | |
T C; // number of columns | |
__host__ __device__ | |
linear_index_to_row_index(T C) : C(C) {} | |
__host__ __device__ | |
T operator()(T i) | |
{ | |
return i / C; | |
} | |
}; | |
typedef thrust::tuple<int,float> argMinType; | |
struct argMin : public thrust::binary_function<argMinType,argMinType,argMinType> | |
{ | |
__host__ __device__ | |
argMinType operator()(const argMinType& a, const argMinType& b) const | |
{ | |
if (thrust::get<1>(a) < thrust::get<1>(b)){ | |
return a; | |
} else { | |
return b; | |
} | |
} | |
}; | |
// Row-wise argmin of a dense, row-major nRows x nColumns matrix stored in A.
//
// Returns one (index, value) pair per row. The value is the row's minimum. The
// index is that element's LINEAR position in A (row * nColumns + column), not its
// column. Tie handling is whatever argMin implements.
//
// A is taken by const reference so the device buffer is not copied per call.
// Throws std::invalid_argument if a dimension is negative, if A holds fewer
// than nRows * nColumns elements (which would otherwise read past the buffer),
// or if nRows * nColumns does not fit in the int indices used below.
// With nColumns == 0 no reduction happens and the nRows pairs stay value-initialized.
thrust::device_vector<argMinType> minsOfRowSpace(const thrust::device_vector<float>& A, int nRows, int nColumns) {
  if (nRows < 0 || nColumns < 0) {
    throw std::invalid_argument("minsOfRowSpace: negative matrix dimension");
  }
  const std::size_t total = static_cast<std::size_t>(nRows) * static_cast<std::size_t>(nColumns);
  if (total > A.size() || total > static_cast<std::size_t>(std::numeric_limits<int>::max())) {
    throw std::invalid_argument("minsOfRowSpace: matrix dimensions do not match A");
  }
  const int nElements = static_cast<int>(total);

  typedef thrust::counting_iterator<int> IndexIterator;
  typedef thrust::transform_iterator<linear_index_to_row_index<int>, IndexIterator> KeyIterator;
  typedef thrust::zip_iterator<thrust::tuple<IndexIterator, thrust::device_vector<float>::const_iterator> > ValueIterator;

  // Keys: the row index of every linear index; each run of equal keys is one row.
  const KeyIterator keysBegin(IndexIterator(0), linear_index_to_row_index<int>(nColumns));
  // Values: (linear index, value) pairs, so the reduction remembers where the minimum is.
  const ValueIterator valuesBegin(thrust::make_tuple(IndexIterator(0), A.begin()));

  thrust::device_vector<argMinType> row_argmins(nRows);
  // The output keys would just be the row numbers; discard them instead of
  // allocating storage that is never read.
  thrust::reduce_by_key(keysBegin, keysBegin + nElements,
                        valuesBegin,
                        thrust::make_discard_iterator(),
                        row_argmins.begin(),
                        thrust::equal_to<int>(),
                        argMin());
  return row_argmins;
}
template <typename T> | |
void printMatrix(thrust::device_vector<T> A, int nRows, int nColumns) { | |
for (size_t i = 0; i < nRows; i++){ | |
for (size_t j = 0; j < nColumns; j++){ | |
std::cout << A[ci(i,j,nColumns)] << " "; | |
} | |
std::cout << "\n"; | |
} | |
std::cout << "\n"; | |
} | |
int main(void) | |
{ | |
int R_A = 5; // number of rows of A | |
int C_A = 8; // number of columns of A | |
// initialize data | |
thrust::device_vector<float> A(R_A * C_A); | |
for (size_t i = 0; i < R_A; i++) | |
for (size_t j = 0; j < C_A; j++) | |
A[ci(i,j,C_A)]=i+j; | |
A[22]=-1; | |
printMatrix<float>(A, R_A, C_A); | |
thrust::device_vector<argMinType> minsOfA=minsOfRowSpace(A, R_A, C_A); | |
for (size_t i = 0; i < R_A; i++){ | |
for (size_t j = 0; j < 1; j++){ | |
argMinType tmp=minsOfA[ci(i,j,1)]; | |
std::cout << thrust::get<0>(tmp) << " "; | |
} | |
std::cout << "\n"; | |
} | |
std::cout << "\n"; | |
return 0; | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.