Skip to content

Instantly share code, notes, and snippets.

@NichtsHsu
Last active September 11, 2023 09:18
Show Gist options
  • Save NichtsHsu/030271c510249edd3bbd11cb90468003 to your computer and use it in GitHub Desktop.
Save NichtsHsu/030271c510249edd3bbd11cb90468003 to your computer and use it in GitHub Desktop.
C++ curry implementation
#include <functional>
#include <iostream>
#include <tuple>
/// The std::tuple type that std::tuple_cat produces when given rvalues of each
/// of the listed tuple types.
template <typename... Tuples>
using tuple_cat_result = decltype(std::tuple_cat(std::declval<Tuples>()...));

/// args_head<std::index_sequence<0, ..., N-1>, Args...>::type is a std::tuple of
/// the first N types in Args (reference types kept as-is). When N exceeds
/// sizeof...(Args), no specialization matches and `type` does not exist.
template <typename IdxSeq, typename... Args>
struct args_head {};

// Base case: no more types are requested.
template <typename... Args>
struct args_head<std::index_sequence<>, Args...> {
  using type = std::tuple<>;
};

// Recursive case: keep T, then take one fewer type from the remainder.
template <typename T, typename... Rest, std::size_t First, std::size_t... Others>
struct args_head<std::index_sequence<First, Others...>, T, Rest...> {
  using tail = typename args_head<std::index_sequence<Others...>, Rest...>::type;
  using type = tuple_cat_result<std::tuple<T>, tail>;
};

/// std::tuple of the first N types of Args.
template <size_t N, typename... Args>
using args_head_t =
    typename args_head<std::make_index_sequence<N>, Args...>::type;
/// args_except<std::index_sequence<0, ..., N-1>, Args...>::type is a std::tuple
/// of Args with its first N types removed. When N exceeds sizeof...(Args), no
/// specialization matches and `type` does not exist.
template <typename IdxSeq, typename... Args>
struct args_except {};

// Nothing left to drop: whatever types remain form the result.
template <typename... Args>
struct args_except<std::index_sequence<>, Args...> {
  using type = std::tuple<Args...>;
};

// Drop the leading type T and continue; `type` is inherited from the base.
template <typename T, typename... Rest, std::size_t Drop, std::size_t... MoreDrops>
struct args_except<std::index_sequence<Drop, MoreDrops...>, T, Rest...>
    : args_except<std::index_sequence<MoreDrops...>, Rest...> {};

/// std::tuple of Args without its first N types.
template <size_t N, typename... Args>
using args_except_t =
    typename args_except<std::make_index_sequence<N>, Args...>::type;
template <typename Callable, typename ArgsTuple, typename UncurriedArgsTuple>
class Curried {};
template <typename Callable, typename... CurriedArgs, typename... UncurriedArgs>
class Curried<Callable, std::tuple<CurriedArgs...>,
std::tuple<UncurriedArgs...>> {
public:
template <typename CallableT>
Curried(CallableT &&callable, std::tuple<CurriedArgs...> &&args)
: _callable(std::forward<CallableT>(callable)),
_curriedArgs(std::move(args)) {}
decltype(auto) operator()(UncurriedArgs... args) const {
auto uncurriedArgs =
std::tuple<UncurriedArgs...>(std::forward<UncurriedArgs>(args)...);
return std::apply(_callable,
std::tuple_cat(_curriedArgs, std::move(uncurriedArgs)));
};
template <typename... Args>
auto curry(std::tuple<Args...> &&args) const & {
using NewCurried = Curried<
Callable,
tuple_cat_result<std::tuple<CurriedArgs...>, std::tuple<Args...>>,
args_except_t<sizeof...(Args), UncurriedArgs...>>;
return NewCurried(_callable, std::tuple_cat(_curriedArgs, args));
}
template <typename... Args>
auto curry(std::tuple<Args...> &&args) && {
using NewCurried = Curried<
Callable,
tuple_cat_result<std::tuple<CurriedArgs...>, std::tuple<Args...>>,
args_except_t<sizeof...(Args), UncurriedArgs...>>;
return NewCurried(std::move(_callable),
std::tuple_cat(std::move(_curriedArgs), args));
}
private:
Callable _callable;
mutable std::tuple<CurriedArgs...> _curriedArgs;
};
/// Curry<std::tuple<Bound...>, std::tuple<Free...>> builds a Curried that wraps
/// a std::function<Ret(Bound..., Free...)> with the Bound arguments fixed.
template <typename CurriedArgsTuple, typename UncurriedArgsTuple>
struct Curry {};

template <typename... CurriedArgs, typename... UncurriedArgs>
struct Curry<std::tuple<CurriedArgs...>, std::tuple<UncurriedArgs...>> {
 private:
  // Shared body of both curry_inner overloads. `fn` is forwarded (copied or
  // moved) into the Curried, and `args` construct the bound-argument tuple.
  template <typename Ret, typename Fn, typename... Args>
  static auto build(Fn &&fn, Args &&...args) {
    using Signature = Ret(CurriedArgs..., UncurriedArgs...);
    using Result = Curried<std::function<Signature>, std::tuple<CurriedArgs...>,
                           std::tuple<UncurriedArgs...>>;
    return Result(std::forward<Fn>(fn),
                  std::tuple<CurriedArgs...>(std::forward<Args>(args)...));
  }

 public:
  /// Copies `callable` into a new Curried whose bound arguments are built
  /// from `args` (one per CurriedArgs type).
  template <typename Ret, typename... Args>
  static auto curry_inner(
      const std::function<Ret(CurriedArgs..., UncurriedArgs...)> &callable,
      Args &&...args) {
    return build<Ret>(callable, std::forward<Args>(args)...);
  }

  /// Moves `callable` into a new Curried whose bound arguments are built
  /// from `args` (one per CurriedArgs type).
  template <typename Ret, typename... Args>
  static auto curry_inner(
      std::function<Ret(CurriedArgs..., UncurriedArgs...)> &&callable,
      Args &&...args) {
    return build<Ret>(std::move(callable), std::forward<Args>(args)...);
  }
};
/// Curries an rvalue std::function, taking ownership of it.
///
/// The first sizeof...(Args) parameters of FullArgs are bound to `args`. Each
/// one is stored as that parameter's declared type: a by-value parameter holds
/// its own copy, and a reference parameter holds the reference. The returned
/// Curried accepts the remaining parameters.
template <typename Ret, typename... FullArgs, typename... Args>
auto curry(std::function<Ret(FullArgs...)> &&callable, Args &&...args) {
  // Without this check, over-binding fails deep inside args_head with an
  // opaque "no type named 'type'" error.
  static_assert(sizeof...(Args) <= sizeof...(FullArgs),
                "curry: more arguments supplied than the function accepts");
  using CurriedArgsTuple = args_head_t<sizeof...(Args), FullArgs...>;
  using UncurriedArgsTuple = args_except_t<sizeof...(Args), FullArgs...>;
  using CurryWrapper = Curry<CurriedArgsTuple, UncurriedArgsTuple>;
  return CurryWrapper::curry_inner(std::move(callable),
                                   std::forward<Args>(args)...);
}
/// Curries an lvalue std::function. The function object is copied, so
/// `callable` stays usable.
///
/// The first sizeof...(Args) parameters of FullArgs are bound to `args`. Each
/// one is stored as that parameter's declared type. The returned Curried
/// accepts the remaining parameters.
template <typename Ret, typename... FullArgs, typename... Args>
auto curry(const std::function<Ret(FullArgs...)> &callable, Args &&...args) {
  // Without this check, over-binding fails deep inside args_head with an
  // opaque "no type named 'type'" error.
  static_assert(sizeof...(Args) <= sizeof...(FullArgs),
                "curry: more arguments supplied than the function accepts");
  using CurriedArgsTuple = args_head_t<sizeof...(Args), FullArgs...>;
  using UncurriedArgsTuple = args_except_t<sizeof...(Args), FullArgs...>;
  using CurryWrapper = Curry<CurriedArgsTuple, UncurriedArgsTuple>;
  return CurryWrapper::curry_inner(callable, std::forward<Args>(args)...);
}
/// Curries a plain function pointer. The pointer is first wrapped in a
/// std::function, which is then passed on as an rvalue.
template <typename Ret, typename... FullArgs, typename... Args>
auto curry(Ret (*func)(FullArgs...), Args &&...args) {
  using Wrapper = std::function<Ret(FullArgs...)>;
  return curry(Wrapper(func), std::forward<Args>(args)...);
}
/// Curries a non-const member function with nothing bound. The object pointer
/// (T*) becomes the first parameter of the returned Curried.
template <typename Ret, typename T, typename... FullArgs>
auto curry(Ret (T::*func)(FullArgs...)) {
  return curry(std::function<Ret(T *, FullArgs...)>(func));
}
/// Curries a non-const member function with the object pointer `caller`
/// already bound, optionally followed by leading arguments `args`.
template <typename Ret, typename T, typename... FullArgs, typename... Args>
auto curry(Ret (T::*func)(FullArgs...), T *caller, Args &&...args) {
  using Wrapper = std::function<Ret(T *, FullArgs...)>;
  return curry(Wrapper(func), caller, std::forward<Args>(args)...);
}
/// Binds further leading arguments onto an existing Curried and returns a new
/// one. `curried` is not modified.
template <typename... CurriedArgsList, typename... Args>
auto curry(const Curried<CurriedArgsList...> &curried, Args &&...args) {
  auto extra = std::tuple<Args...>(std::forward<Args>(args)...);
  return curried.curry(std::move(extra));
}
/// Binds further leading arguments onto an rvalue Curried. The stored callable
/// and stored arguments are moved into the result rather than copied.
template <typename... CurriedArgsList, typename... Args>
auto curry(Curried<CurriedArgsList...> &&curried, Args &&...args) {
  // `curried` is a named rvalue reference and therefore an lvalue here.
  // std::move is required to select the &&-qualified Curried::curry overload;
  // otherwise the copying const& overload is always chosen.
  return std::move(curried).curry(
      std::tuple<Args...>(std::forward<Args>(args)...));
}
/// Small aggregate whose member function is used to demonstrate currying of
/// member-function pointers.
struct Test {
  int a;

  /// Writes `a` followed by std::endl (newline + flush) to std::cout.
  void print() {
    const int value = this->a;
    std::cout << value << std::endl;
  }
};
/// Writes the four arguments to std::cout separated by single spaces, followed
/// by std::endl (newline + flush).
void print(int a, double b, char c, const char *d) {
  std::ostream &out = std::cout;
  out << a << " " << b << " ";
  out << c << " " << d;
  out << std::endl;
}
/// Adds `y` to the integer referenced by `x`, in place.
void add(int &x, int y) {
  x = x + y;
}
/// Returns its argument by reference, so callers can assign through the result.
int &retref(int &x) {
  int &same = x;
  return same;
}
int main() {
auto curriedNoArg = curry(print);
curriedNoArg(1919, 114.514, 'a', "hell word");
auto curriedFullArgs = curry(print, 1919, 114.514, 'a', "hell word");
curriedFullArgs();
auto curriedOneArg = curry(print, 1919);
curriedOneArg(114.514, 'a', "hell word");
auto curriedOfCurried = curry(curriedOneArg, 114.514);
curriedOfCurried('a', "hell word");
Test t = {810};
auto memberFuncCurriedNoArg = curry(&Test::print);
memberFuncCurriedNoArg(&t);
auto memberFuncCurriedSelfArg = curry(&Test::print, &t);
memberFuncCurriedSelfArg();
int x = 1;
auto curriedWithRef = curry(add, x);
curriedWithRef(2);
std::cout << x << std::endl;
auto curriedRetRef = curry(retref, x);
curriedRetRef() = 5;
std::cout << x << std::endl;
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment