-
-
Save NichtsHsu/030271c510249edd3bbd11cb90468003 to your computer and use it in GitHub Desktop.
C++ curry implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <functional> | |
#include <iostream> | |
#include <tuple> | |
/// The std::tuple type that std::tuple_cat yields when given rvalues of each
/// of TupleTypes (e.g. tuple_cat_result<std::tuple<A>, std::tuple<B&>> is
/// std::tuple<A, B&>).
template <typename... TupleTypes>
using tuple_cat_result =
    decltype(std::tuple_cat(std::declval<TupleTypes &&>()...));
/// Collects the first N types of Args into a std::tuple exposed as `type`,
/// where N is the length of IdxSeq (only its length matters, not its values).
template <typename IdxSeq, typename... Args>
struct args_head {};

// Recursion end: no indices left, so nothing more is taken.
template <typename... Args>
struct args_head<std::integer_sequence<size_t>, Args...> {
  using type = std::tuple<>;
};

// Take the leading type, then prepend it to the head of the remainder.
template <typename First, typename... Tail, size_t Head, size_t... More>
struct args_head<std::integer_sequence<size_t, Head, More...>, First, Tail...> {
 private:
  using remainder =
      typename args_head<std::integer_sequence<size_t, More...>, Tail...>::type;

 public:
  using type = decltype(std::tuple_cat(std::declval<std::tuple<First>>(),
                                       std::declval<remainder>()));
};

/// Tuple of the first N types of Args (e.g. args_head_t<2, A, B, C> is
/// std::tuple<A, B>).
template <size_t N, typename... Args>
using args_head_t =
    typename args_head<std::make_index_sequence<N>, Args...>::type;
/// Drops the first N types of Args, where N is the length of IdxSeq (only its
/// length matters, not its values), and exposes the remaining types as a
/// std::tuple named `type`.
template <typename IdxSeq, typename... Args>
struct args_except {};

// Recursion end: no more types to drop, keep everything that is left.
template <typename... Args>
struct args_except<std::integer_sequence<size_t>, Args...> {
  using type = std::tuple<Args...>;
};

// Drop the leading type and continue with one index fewer.
template <typename Dropped, typename... Remaining, size_t Head, size_t... More>
struct args_except<std::integer_sequence<size_t, Head, More...>, Dropped,
                   Remaining...> {
 private:
  using next =
      args_except<std::integer_sequence<size_t, More...>, Remaining...>;

 public:
  using type = typename next::type;
};

/// Tuple of Args without its first N types (e.g. args_except_t<1, A, B, C>
/// is std::tuple<B, C>).
template <size_t N, typename... Args>
using args_except_t =
    typename args_except<std::make_index_sequence<N>, Args...>::type;
template <typename Callable, typename ArgsTuple, typename UncurriedArgsTuple> | |
class Curried {}; | |
template <typename Callable, typename... CurriedArgs, typename... UncurriedArgs> | |
class Curried<Callable, std::tuple<CurriedArgs...>, | |
std::tuple<UncurriedArgs...>> { | |
public: | |
template <typename CallableT> | |
Curried(CallableT &&callable, std::tuple<CurriedArgs...> &&args) | |
: _callable(std::forward<CallableT>(callable)), | |
_curriedArgs(std::move(args)) {} | |
decltype(auto) operator()(UncurriedArgs... args) const { | |
auto uncurriedArgs = | |
std::tuple<UncurriedArgs...>(std::forward<UncurriedArgs>(args)...); | |
return std::apply(_callable, | |
std::tuple_cat(_curriedArgs, std::move(uncurriedArgs))); | |
}; | |
template <typename... Args> | |
auto curry(std::tuple<Args...> &&args) const & { | |
using NewCurried = Curried< | |
Callable, | |
tuple_cat_result<std::tuple<CurriedArgs...>, std::tuple<Args...>>, | |
args_except_t<sizeof...(Args), UncurriedArgs...>>; | |
return NewCurried(_callable, std::tuple_cat(_curriedArgs, args)); | |
} | |
template <typename... Args> | |
auto curry(std::tuple<Args...> &&args) && { | |
using NewCurried = Curried< | |
Callable, | |
tuple_cat_result<std::tuple<CurriedArgs...>, std::tuple<Args...>>, | |
args_except_t<sizeof...(Args), UncurriedArgs...>>; | |
return NewCurried(std::move(_callable), | |
std::tuple_cat(std::move(_curriedArgs), args)); | |
} | |
private: | |
Callable _callable; | |
mutable std::tuple<CurriedArgs...> _curriedArgs; | |
}; | |
template <typename CurriedArgsTuple, typename UncurriedArgsTuple> | |
struct Curry {}; | |
template <typename... CurriedArgs, typename... UncurriedArgs> | |
struct Curry<std::tuple<CurriedArgs...>, std::tuple<UncurriedArgs...>> { | |
template <typename Ret, typename... Args> | |
static auto curry_inner( | |
const std::function<Ret(CurriedArgs..., UncurriedArgs...)> &callable, | |
Args &&...args) { | |
using Callable = std::function<Ret(CurriedArgs..., UncurriedArgs...)>; | |
using CurriedT = Curried<Callable, std::tuple<CurriedArgs...>, | |
std::tuple<UncurriedArgs...>>; | |
return CurriedT(callable, | |
std::tuple<CurriedArgs...>(std::forward<Args>(args)...)); | |
} | |
template <typename Ret, typename... Args> | |
static auto curry_inner( | |
std::function<Ret(CurriedArgs..., UncurriedArgs...)> &&callable, | |
Args &&...args) { | |
using Callable = std::function<Ret(CurriedArgs..., UncurriedArgs...)>; | |
using CurriedT = Curried<Callable, std::tuple<CurriedArgs...>, | |
std::tuple<UncurriedArgs...>>; | |
return CurriedT(std::move(callable), | |
std::tuple<CurriedArgs...>(std::forward<Args>(args)...)); | |
} | |
}; | |
template <typename Ret, typename... FullArgs, typename... Args> | |
auto curry(std::function<Ret(FullArgs...)> &&callable, Args &&...args) { | |
using CurriedArgsTuple = args_head_t<sizeof...(Args), FullArgs...>; | |
using UncurriedArgsTuple = args_except_t<sizeof...(Args), FullArgs...>; | |
using CurryWrapper = Curry<CurriedArgsTuple, UncurriedArgsTuple>; | |
return CurryWrapper::curry_inner(std::move(callable), | |
std::forward<Args>(args)...); | |
} | |
template <typename Ret, typename... FullArgs, typename... Args> | |
auto curry(const std::function<Ret(FullArgs...)> &callable, Args &&...args) { | |
using CurriedArgsTuple = args_head_t<sizeof...(Args), FullArgs...>; | |
using UncurriedArgsTuple = args_except_t<sizeof...(Args), FullArgs...>; | |
using CurryWrapper = Curry<CurriedArgsTuple, UncurriedArgsTuple>; | |
return CurryWrapper::curry_inner(callable, std::forward<Args>(args)...); | |
} | |
/// Curries a plain function pointer by first wrapping it in a std::function.
template <typename Ret, typename... FullArgs, typename... Args>
auto curry(Ret (*func)(FullArgs...), Args &&...args) {
  using Wrapped = std::function<Ret(FullArgs...)>;
  return curry(Wrapped(func), std::forward<Args>(args)...);
}
/// Curries a non-static member function with nothing bound yet. The result
/// takes the object pointer (`T*`) as its first argument, followed by the
/// member function's own parameters.
template <typename Ret, typename T, typename... FullArgs>
auto curry(Ret (T::*func)(FullArgs...)) {
  std::function<Ret(T *, FullArgs...)> callable = func;
  return curry(std::move(callable));
}

/// Same as above for const-qualified member functions; the object pointer is
/// `const T*`. Without this overload a const member function pointer matches
/// no curry() overload.
template <typename Ret, typename T, typename... FullArgs>
auto curry(Ret (T::*func)(FullArgs...) const) {
  std::function<Ret(const T *, FullArgs...)> callable = func;
  return curry(std::move(callable));
}
/// Curries a non-static member function, binding `caller` as the object
/// pointer and `args` to the member function's leading parameters.
template <typename Ret, typename T, typename... FullArgs, typename... Args>
auto curry(Ret (T::*func)(FullArgs...), T *caller, Args &&...args) {
  std::function<Ret(T *, FullArgs...)> callable = func;
  return curry(std::move(callable), caller, std::forward<Args>(args)...);
}

/// Same as above for const-qualified member functions; `caller` may point to
/// a const or non-const object.
template <typename Ret, typename T, typename... FullArgs, typename... Args>
auto curry(Ret (T::*func)(FullArgs...) const, const T *caller,
           Args &&...args) {
  std::function<Ret(const T *, FullArgs...)> callable = func;
  return curry(std::move(callable), caller, std::forward<Args>(args)...);
}
template <typename... CurriedArgsList, typename... Args> | |
auto curry(const Curried<CurriedArgsList...> &curried, Args &&...args) { | |
return curried.curry(std::tuple<Args...>(std::forward<Args>(args)...)); | |
} | |
template <typename... CurriedArgsList, typename... Args> | |
auto curry(Curried<CurriedArgsList...> &&curried, Args &&...args) { | |
return curried.curry(std::tuple<Args...>(std::forward<Args>(args)...)); | |
} | |
/// Small aggregate used by main() to demonstrate currying member functions.
struct Test {
  int a;  // value written out by print()

  /// Writes `a` followed by std::endl to std::cout.
  void print() {
    std::cout << a;
    std::cout << std::endl;
  }
};
/// Writes the four arguments to std::cout separated by single spaces,
/// followed by std::endl.
void print(int a, double b, char c, const char *d) {
  std::ostream &out = std::cout;
  out << a;
  out << " " << b;
  out << " " << c;
  out << " " << d;
  out << std::endl;
}
/// Adds `y` to the caller's `x` in place.
auto add(int &x, int y) -> void { x = x + y; }
/// Returns a reference to the very object passed in.
auto retref(int &x) -> int & { return x; }
int main() { | |
auto curriedNoArg = curry(print); | |
curriedNoArg(1919, 114.514, 'a', "hell word"); | |
auto curriedFullArgs = curry(print, 1919, 114.514, 'a', "hell word"); | |
curriedFullArgs(); | |
auto curriedOneArg = curry(print, 1919); | |
curriedOneArg(114.514, 'a', "hell word"); | |
auto curriedOfCurried = curry(curriedOneArg, 114.514); | |
curriedOfCurried('a', "hell word"); | |
Test t = {810}; | |
auto memberFuncCurriedNoArg = curry(&Test::print); | |
memberFuncCurriedNoArg(&t); | |
auto memberFuncCurriedSelfArg = curry(&Test::print, &t); | |
memberFuncCurriedSelfArg(); | |
int x = 1; | |
auto curriedWithRef = curry(add, x); | |
curriedWithRef(2); | |
std::cout << x << std::endl; | |
auto curriedRetRef = curry(retref, x); | |
curriedRetRef() = 5; | |
std::cout << x << std::endl; | |
return 0; | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.