Skip to content

Instantly share code, notes, and snippets.

@jaredhoberock
Created April 6, 2012 06:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jaredhoberock/2317505 to your computer and use it in GitHub Desktop.
Save jaredhoberock/2317505 to your computer and use it in GitHub Desktop.
How to build merge_by_key using thrust::merge and thrust::zip_iterator
#include <thrust/functional.h>
#include <thrust/iterator/iterator_traits.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/merge.h>
#include <thrust/pair.h>
#include <thrust/tuple.h>

#include <algorithm>
#include <iostream>
#include <iterator>
// Binary predicate over tuples that compares only their 0th elements
// (via thrust::get<0>) using a wrapped comparator. The remaining tuple
// elements are ignored, which lets thrust::merge order zipped
// (key, value) tuples by key alone.
//
// Compare: binary comparator invocable on the 0th element types.
template<typename Compare>
struct compare_first
{
  compare_first(Compare c)
    : comp(c)
  {}

  // Returns comp(get<0>(lhs), get<0>(rhs)).
  // const so the predicate can be invoked through a const object or
  // reference, as expected of comparison function objects.
  template<typename Tuple1, typename Tuple2>
  __host__ __device__
  bool operator()(const Tuple1 &lhs, const Tuple2 &rhs) const
  {
    return comp(thrust::get<0>(lhs), thrust::get<0>(rhs));
  }

  // mutable so that a Compare whose operator() is non-const can still be
  // called from the const call operator above.
  mutable Compare comp;
};
// Factory that deduces Compare from its argument and wraps the comparator
// in a compare_first, so callers need not spell out the template argument.
template<typename Compare>
compare_first<Compare> make_compare_first(Compare comp)
{
  compare_first<Compare> wrapped(comp);
  return wrapped;
}
// Merges two key ranges, ordered by comp, carrying each key's associated
// value along with it.
//
//   [keys_first1, keys_last1) with values starting at values_first1
//   [keys_first2, keys_last2) with values starting at values_first2
//
// Merged keys are written to keys_result and the corresponding values to
// values_result. Both key ranges are expected to already be ordered by comp,
// as thrust::merge requires.
//
// Returns the past-the-end iterators of the written key and value outputs.
template<typename ForwardIterator1,
         typename ForwardIterator2,
         typename ForwardIterator3,
         typename ForwardIterator4,
         typename OutputIterator1,
         typename OutputIterator2,
         typename Compare>
thrust::pair<OutputIterator1,OutputIterator2>
merge_by_key(ForwardIterator1 keys_first1,
             ForwardIterator1 keys_last1,
             ForwardIterator2 values_first1,
             ForwardIterator3 keys_first2,
             ForwardIterator3 keys_last2,
             ForwardIterator4 values_first2,
             OutputIterator1 keys_result,
             OutputIterator2 values_result,
             Compare comp)
{
  typedef thrust::zip_iterator<thrust::tuple<ForwardIterator1,ForwardIterator2> > zipped_input1;
  typedef thrust::zip_iterator<thrust::tuple<ForwardIterator3,ForwardIterator4> > zipped_input2;
  typedef thrust::tuple<OutputIterator1,OutputIterator2>                        output_tuple;

  // View each (key, value) pair of sequences as one sequence of tuples.
  // NOTE(review): the "last" iterators pair keys_last with values_first as
  // a placeholder; this presumably relies on zip_iterator comparing and
  // measuring distance by its first component only -- confirm against the
  // Thrust version in use.
  zipped_input1 first1 = thrust::make_zip_iterator(thrust::make_tuple(keys_first1, values_first1));
  zipped_input1 last1  = thrust::make_zip_iterator(thrust::make_tuple(keys_last1,  values_first1));
  zipped_input2 first2 = thrust::make_zip_iterator(thrust::make_tuple(keys_first2, values_first2));
  zipped_input2 last2  = thrust::make_zip_iterator(thrust::make_tuple(keys_last2,  values_first2));

  // Merge on keys only; values ride along inside the tuples.
  thrust::zip_iterator<output_tuple> merged_end =
    thrust::merge(first1, last1,
                  first2, last2,
                  thrust::make_zip_iterator(thrust::make_tuple(keys_result, values_result)),
                  make_compare_first(comp));

  // Split the zipped past-the-end iterator back into key / value iterators.
  output_tuple ends = merged_end.get_iterator_tuple();
  return thrust::make_pair(thrust::get<0>(ends), thrust::get<1>(ends));
}
// merge_by_key overload that orders keys with thrust::less, instantiated on
// the value type of the first key range (ForwardIterator1).
//
// Parameters and return value are the same as for the overload that takes an
// explicit Compare; this one simply forwards to it.
template<typename ForwardIterator1,
         typename ForwardIterator2,
         typename ForwardIterator3,
         typename ForwardIterator4,
         typename OutputIterator1,
         typename OutputIterator2>
thrust::pair<OutputIterator1,OutputIterator2>
merge_by_key(ForwardIterator1 keys_first1,
             ForwardIterator1 keys_last1,
             ForwardIterator2 values_first1,
             ForwardIterator3 keys_first2,
             ForwardIterator3 keys_last2,
             ForwardIterator4 values_first2,
             OutputIterator1 keys_result,
             OutputIterator2 values_result)
{
  typedef typename thrust::iterator_value<ForwardIterator1>::type value_type;

  // The callee's name is parenthesized on purpose, which disables
  // argument-dependent lookup. Otherwise the thrust::less argument makes
  // namespace thrust an associated namespace. In Thrust releases that ship
  // their own thrust::merge_by_key (same arity, but with keys_first2 /
  // keys_last2 placed before the value ranges), both templates are then
  // viable for common argument types such as raw pointers, and the call
  // is ambiguous.
  return (merge_by_key)(keys_first1, keys_last1, values_first1,
                        keys_first2, keys_last2, values_first2,
                        keys_result, values_result,
                        thrust::less<value_type>());
}
int main()
{
int keys1[5] = {1,2,4,7,9};
int values1[5] = {0,0,1,1,2};
int keys2[7] = {0,3,5,6,8,10,11};
int values2[7] = {0,0,1,1,2,2,2};
int keys_result[12];
int values_result[12];
merge_by_key(keys1, keys1 + 5, values1, keys2, keys2 + 7, values2, keys_result, values_result);
std::cout << "keys_result: ";
std::copy(keys_result, keys_result + 12, std::ostream_iterator<int>(std::cout, " "));
std::cout << std::endl;
std::cout << "values_result: ";
std::copy(values_result, values_result + 12, std::ostream_iterator<int>(std::cout, " "));
std::cout << std::endl;
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment