Created April 6, 2012 06:11
-
-
Save jaredhoberock/2317505 to your computer and use it in GitHub Desktop.
How to build merge_by_key using thrust::merge and thrust::zip_iterator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <thrust/pair.h>
#include <thrust/tuple.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/iterator/iterator_traits.h>
#include <thrust/functional.h>
#include <thrust/merge.h>
#include <thrust/advance.h>
#include <thrust/distance.h>
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
// compares the 0th elements of two tuples | |
template<typename Compare> | |
struct compare_first | |
{ | |
compare_first(Compare comp) | |
: comp(comp) | |
{} | |
template<typename Tuple1, typename Tuple2> | |
__host__ __device__ | |
bool operator()(const Tuple1 &lhs, const Tuple2 &rhs) | |
{ | |
return comp(thrust::get<0>(lhs), thrust::get<0>(rhs)); | |
} | |
Compare comp; | |
}; | |
template<typename Compare> | |
compare_first<Compare> make_compare_first(Compare comp) | |
{ | |
return compare_first<Compare>(comp); | |
} | |
template<typename ForwardIterator1, | |
typename ForwardIterator2, | |
typename ForwardIterator3, | |
typename ForwardIterator4, | |
typename OutputIterator1, | |
typename OutputIterator2, | |
typename Compare> | |
thrust::pair<OutputIterator1,OutputIterator2> | |
merge_by_key(ForwardIterator1 keys_first1, | |
ForwardIterator1 keys_last1, | |
ForwardIterator2 values_first1, | |
ForwardIterator3 keys_first2, | |
ForwardIterator3 keys_last2, | |
ForwardIterator4 values_first2, | |
OutputIterator1 keys_result, | |
OutputIterator2 values_result, | |
Compare comp) | |
{ | |
typedef thrust::zip_iterator<thrust::tuple<OutputIterator1,OutputIterator2> > result_type; | |
// zip everything together | |
result_type r = | |
thrust::merge(thrust::make_zip_iterator(thrust::make_tuple(keys_first1, values_first1)), | |
thrust::make_zip_iterator(thrust::make_tuple(keys_last1, values_first1)), | |
thrust::make_zip_iterator(thrust::make_tuple(keys_first2, values_first2)), | |
thrust::make_zip_iterator(thrust::make_tuple(keys_last2, values_first2)), | |
thrust::make_zip_iterator(thrust::make_tuple(keys_result, values_result)), | |
make_compare_first(comp)); | |
thrust::tuple<OutputIterator1,OutputIterator2> tup = r.get_iterator_tuple(); | |
return thrust::make_pair(thrust::get<0>(tup), thrust::get<1>(tup)); | |
} | |
template<typename ForwardIterator1, | |
typename ForwardIterator2, | |
typename ForwardIterator3, | |
typename ForwardIterator4, | |
typename OutputIterator1, | |
typename OutputIterator2> | |
thrust::pair<OutputIterator1,OutputIterator2> | |
merge_by_key(ForwardIterator1 keys_first1, | |
ForwardIterator1 keys_last1, | |
ForwardIterator2 values_first1, | |
ForwardIterator3 keys_first2, | |
ForwardIterator3 keys_last2, | |
ForwardIterator4 values_first2, | |
OutputIterator1 keys_result, | |
OutputIterator2 values_result) | |
{ | |
typedef typename thrust::iterator_value<ForwardIterator1>::type value_type; | |
return merge_by_key(keys_first1, keys_last1, values_first1, keys_first2, keys_last2, values_first2, keys_result, values_result, thrust::less<value_type>()); | |
} | |
int main() | |
{ | |
int keys1[5] = {1,2,4,7,9}; | |
int values1[5] = {0,0,1,1,2}; | |
int keys2[7] = {0,3,5,6,8,10,11}; | |
int values2[7] = {0,0,1,1,2,2,2}; | |
int keys_result[12]; | |
int values_result[12]; | |
merge_by_key(keys1, keys1 + 5, values1, keys2, keys2 + 7, values2, keys_result, values_result); | |
std::cout << "keys_result: "; | |
std::copy(keys_result, keys_result + 12, std::ostream_iterator<int>(std::cout, " ")); | |
std::cout << std::endl; | |
std::cout << "values_result: "; | |
std::copy(values_result, values_result + 12, std::ostream_iterator<int>(std::cout, " ")); | |
std::cout << std::endl; | |
return 0; | |
} | |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.