Last active
April 17, 2023 00:45
-
-
Save thynson/bd280b05b8a766e8f6399073cdf94150 to your computer and use it in GitHub Desktop.
A general branchless comparison implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <functional> | |
#include <iostream> | |
#include <utility> | |
/**
 * Terminal case of the lexicographic comparison: exactly one (comparator, lhs, rhs) triple.
 *
 * Participates in overload resolution only when comparator(lhs, rhs) yields exactly bool.
 *
 * @param comparator predicate applied to (lhs, rhs)
 * @param lhs        left operand
 * @param rhs        right operand
 * @return (char) 1 when comparator(lhs, rhs) is true, (char) 0 otherwise
 */
template<typename Comparator, typename T>
std::enable_if_t<
        std::is_same_v<bool,
                       decltype(std::declval<const Comparator &>()(std::declval<const T &>(),
                                                                   std::declval<const T &>()))>,
        char>
branchless_compare_impl(const Comparator &comparator, const T &lhs, const T &rhs)
{
    const bool ordered = comparator(lhs, rhs);
    return static_cast<char>(ordered);
}
/**
 * Recursive case of the lexicographic comparison: handles the leading
 * (comparator, lhs, rhs) triple and folds in the result for the remaining triples.
 *
 * Both comparator(lhs, rhs) and comparator(rhs, lhs) are always evaluated, and so
 * is the recursion, so the result is produced without short-circuit evaluation.
 *
 * @param comparator predicate for this position; must return exactly bool
 * @param lhs        left operand for this position
 * @param rhs        right operand for this position
 * @param remains    further (comparator, lhs, rhs) triples. They are taken by const
 *                   reference so nothing is copied as the recursion descends.
 * @return nonzero iff comparator(lhs, rhs) holds, or neither comparator(lhs, rhs)
 *         nor comparator(rhs, lhs) holds and the remaining triples compare nonzero
 */
template<typename Comparator, typename T, typename ...Remains>
typename std::enable_if<
        std::is_same<
                bool,
                decltype(std::declval<const Comparator &>()(std::declval<const T &>(), std::declval<const T &>()))
        >::value,
        char
>::type
branchless_compare_impl(const Comparator &comparator, const T &lhs, const T &rhs, const Remains &...remains)
{
    // Map true -> all bits set (-1) and false -> 0, so that a logical "not" of
    // these masks can be computed with bitwise "~".
    const char lhs_less_rhs = static_cast<char>(-static_cast<int>(comparator(lhs, rhs)));
    const char rhs_less_lhs = static_cast<char>(-static_cast<int>(comparator(rhs, lhs)));
    // Fold with bitwise operators only (no && / ||), so nothing short-circuits:
    // "less here" OR ("not greater here" AND "less on the remaining positions").
    return static_cast<char>(lhs_less_rhs | (~rhs_less_lhs & branchless_compare_impl(remains...)));
}
/**
 * Performs a lexicographic comparison over a sequence of (comparator, lhs, rhs)
 * triples without short-circuit evaluation.
 *
 * Every triple's comparator is invoked regardless of earlier results. All triples
 * except the last are compared in both directions; the last is compared as
 * (lhs, rhs) only. The whole procedure is branchless only when the underlying
 * comparators are themselves branchless.
 *
 * @param args (comparator, lhs, rhs) triples, most significant first. Each
 *             comparator must return exactly bool for its pair of operands.
 *             Arguments are taken by const reference, so no operand is copied.
 * @return true iff, at the first triple where neither comparator(lhs, rhs) nor
 *         comparator(rhs, lhs) leaves them tied, comparator(lhs, rhs) holds.
 *         Returns false when every triple is tied.
 */
template<typename ...Args>
[[nodiscard]] bool branchless_compare(const Args &...args)
{
    return branchless_compare_impl(args...) != static_cast<char>(0);
}
/**
 * Plain aggregate of four ints used to demonstrate branchless_compare; it is
 * ordered lexicographically on (a, b, c, d).
 *
 * Members default to 0 so a default-initialized instance never holds
 * indeterminate values. The type remains an aggregate, so brace
 * initialization such as BranchlessComparable{1, 2, 3, 4} still works.
 */
struct BranchlessComparable
{
    int a{};
    int b{};
    int c{};
    int d{};
};
template<> | |
struct std::less<BranchlessComparable> | |
{ | |
bool operator()(const BranchlessComparable &lhs, const BranchlessComparable &rhs) const; | |
}; | |
bool std::less<BranchlessComparable>::operator()(const BranchlessComparable &lhs, const BranchlessComparable &rhs) const | |
{ | |
return branchless_compare(std::less(), lhs.a, rhs.a, | |
std::less(), lhs.b, rhs.b, | |
std::less(), lhs.c, rhs.c, | |
std::less<int>(), lhs.d, rhs.d); | |
} | |
int main() | |
{ | |
std::cout << std::less<BranchlessComparable>()(BranchlessComparable{1, 2, 3, 3}, BranchlessComparable{1, 2, 3, 4}) | |
<< std::endl; | |
std::cout << std::less<BranchlessComparable>()(BranchlessComparable{1, 2, 3, 4}, BranchlessComparable{1, 2, 3, 4}) | |
<< std::endl; | |
std::cout << std::less<BranchlessComparable>()(BranchlessComparable{1, 2, 3, 5}, BranchlessComparable{1, 2, 3, 4}) | |
<< std::endl; | |
std::cout << std::less<BranchlessComparable>()(BranchlessComparable{1, 2, 2, 4}, BranchlessComparable{1, 2, 3, 4}) | |
<< std::endl; | |
std::cout << std::less<BranchlessComparable>()(BranchlessComparable{1, 2, 3, 4}, BranchlessComparable{1, 2, 3, 4}) | |
<< std::endl; | |
std::cout << std::less<BranchlessComparable>()(BranchlessComparable{1, 2, 4, 4}, BranchlessComparable{1, 2, 3, 4}) | |
<< std::endl; | |
std::cout << std::less<BranchlessComparable>()(BranchlessComparable{1, 1, 3, 4}, BranchlessComparable{1, 2, 3, 4}) | |
<< std::endl; | |
std::cout << std::less<BranchlessComparable>()(BranchlessComparable{1, 2, 3, 4}, BranchlessComparable{1, 2, 3, 4}) | |
<< std::endl; | |
std::cout << std::less<BranchlessComparable>()(BranchlessComparable{1, 3, 3, 4}, BranchlessComparable{1, 2, 3, 4}) | |
<< std::endl; | |
std::cout << std::less<BranchlessComparable>()(BranchlessComparable{0, 2, 3, 4}, BranchlessComparable{1, 2, 3, 4}) | |
<< std::endl; | |
std::cout << std::less<BranchlessComparable>()(BranchlessComparable{1, 2, 3, 4}, BranchlessComparable{1, 2, 3, 4}) | |
<< std::endl; | |
std::cout << std::less<BranchlessComparable>()(BranchlessComparable{2, 2, 3, 4}, BranchlessComparable{1, 2, 3, 4}) | |
<< std::endl; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment