Skip to content

Instantly share code, notes, and snippets.

@NichtsHsu
Last active August 7, 2023 02:04
Show Gist options
  • Save NichtsHsu/9b2f490675c71cda641a3a137778f70e to your computer and use it in GitHub Desktop.
Save NichtsHsu/9b2f490675c71cda641a3a137778f70e to your computer and use it in GitHub Desktop.
C++ curry without any standard library or lambda
#include <iostream>
// Minimal stand-in for std::enable_if. The primary template has no nested
// `type`, so `enable_if<false, T>::type` is a substitution failure.
template <bool, typename T = void>
struct enable_if {};

// Only the `true` case exposes `type`.
template <typename T>
struct enable_if<true, T> {
  using type = T;
};

// Shorthand for `typename enable_if<Condition, T>::type`. `T` defaults to
// `void`, matching the default on `enable_if` itself, so both spellings accept
// the same argument lists.
template <bool Condition, typename T = void>
using enable_if_t = typename enable_if<Condition, T>::type;
// Type-equality trait. For two unrelated types `value` is false.
template <typename, typename>
struct is_same {
  static constexpr bool value = false;
};

// Selected when both template arguments name exactly the same type.
template <typename Same>
struct is_same<Same, Same> {
  static constexpr bool value = true;
};

// Variable-template shorthand for `is_same<Lhs, Rhs>::value`.
template <typename Lhs, typename Rhs>
static constexpr bool is_same_v = is_same<Lhs, Rhs>::value;
// Strips one level of reference from a type. Non-reference types pass through
// unchanged.
template <typename Plain>
struct remove_reference {
  typedef Plain type;
};

// Lvalue reference: `Referred &` -> `Referred`.
template <typename Referred>
struct remove_reference<Referred &> {
  typedef Referred type;
};

// Rvalue reference: `Referred &&` -> `Referred`.
template <typename Referred>
struct remove_reference<Referred &&> {
  typedef Referred type;
};

// Shorthand for `typename remove_reference<T>::type`.
template <typename T>
using remove_reference_t = typename remove_reference<T>::type;
// Dispatch tags used to choose how a callable is invoked.
struct normal_function {};  // called as `callable(args...)`
struct member_function {};  // called as `(object->*callable)(args...)`

// Maps a callable type to its dispatch tag. Everything that is not a
// pointer-to-member is treated as an ordinary callable.
template <typename Callable>
struct callable_tag {
  using type = normal_function;
};

// Any pointer-to-member type. Note that the pattern `F T::*` does not
// distinguish member functions from data members.
template <typename F, typename T>
struct callable_tag<F T::*> {
  using type = member_function;
};

// Same for a const-qualified pointer-to-member, e.g. a `const auto pm = &T::f;`
// variable, which would otherwise be classified as an ordinary callable.
template <typename F, typename T>
struct callable_tag<F T::*const> {
  using type = member_function;
};

// Tag for `Callable` after stripping any reference, so forwarding-reference
// deductions such as `Fn &` map to the same tag as `Fn`.
template <typename Callable>
using callable_tag_t =
    typename callable_tag<remove_reference_t<Callable>>::type;
// Compile-time list of indices, carried purely in the type.
template <std::size_t... Ns>
struct integer_sequence {};

// Builds integer_sequence<0, 1, ..., N-1> by prepending N-1 to the accumulated
// pack `Is` until N reaches zero. It is only ever used inside decltype, so
// just its return type matters.
template <std::size_t N, std::size_t... Is>
constexpr auto make_index_sequence_inner() {
  if constexpr (N == 0)
    return integer_sequence<Is...>();
  else
    return make_index_sequence_inner<N - 1, N - 1, Is...>();
}

// integer_sequence<0, ..., N-1>. The helper returns by value, so decltype
// already yields a non-reference type and no reference stripping is needed.
template <std::size_t N>
using make_index_sequence = decltype(make_index_sequence_inner<N>());
template <typename T>
constexpr T &&forward(remove_reference_t<T> &t) noexcept {
return static_cast<T &&>(t);
}
template <typename T>
constexpr T &&forward(remove_reference_t<T> &&t) noexcept {
return static_cast<T &&>(t);
}
template <typename T>
constexpr remove_reference_t<T> &&move(T &&t) noexcept {
return static_cast<remove_reference_t<T> &&>(t);
}
template <typename Callable, typename... Args>
decltype(auto) invoke_inner(normal_function, Callable &&callable,
Args &&...args) {
return callable(forward<Args>(args)...);
}
template <typename Callable, typename Caller, typename... Args>
decltype(auto) invoke_inner(member_function, Callable &&callable,
Caller &&caller, Args &&...args) {
return (caller->*callable)(forward<Args>(args)...);
}
template <typename Callable, typename... Args>
decltype(auto) invoke(Callable &&callable, Args &&...args) {
return invoke_inner(callable_tag_t<Callable>{}, callable,
forward<Args>(args)...);
}
template <typename T>
class Function {
Function() = delete;
};
template <typename Ret, typename... Args>
class Function<Ret(Args...)> {
private:
template <typename Callable>
static Ret invoke(void *callable, Args... args) {
return ::invoke(*reinterpret_cast<remove_reference_t<Callable> *>(callable),
forward<Args>(args)...);
}
template <typename Callable>
static void *copy(void *callable) {
return new remove_reference_t<Callable>(
*reinterpret_cast<remove_reference_t<Callable> *>(callable));
}
template <typename Callable>
static void free(void *callable) {
delete reinterpret_cast<remove_reference_t<Callable> *>(callable);
}
public:
template <typename Callable,
enable_if_t<!is_same_v<remove_reference_t<Callable>,
Function<Ret(Args...)>>,
bool> = true>
Function(Callable &&callable) {
_callable = new remove_reference_t<Callable>(forward<Callable>(callable));
_invoke = invoke<Callable>;
_free = free<Callable>;
_copy = copy<Callable>;
}
Function(const Function &other) {
_callable = other._copy(other._callable);
_invoke = other._invoke;
_free = other._free;
_copy = other._copy;
};
Function(Function &&other) {
_callable = other._callable;
other._callable = nullptr;
_invoke = other._invoke;
_free = other._free;
_copy = other._copy;
}
~Function() { _free(_callable); }
Ret operator()(Args... args) const {
return _invoke(_callable, forward<Args>(args)...);
}
private:
void *_callable;
Ret (*_invoke)(void *, Args...);
void (*_free)(void *);
void *(*_copy)(void *);
};
// Holds one tuple element. `Idx` only makes the type unique per position so a
// tuple can inherit from one tuple_storage per element, even when element
// types repeat.
template <size_t Idx, typename T>
class tuple_storage {
 public:
  // Copies the element from an existing value.
  tuple_storage(const T &value) : value(value) {}

  // Constructs the element from any argument, preserving its value category.
  // `static_cast<U &&>` is the forwarding cast written out, so this class does
  // not rely on std::forward being visible through another header.
  template <typename U>
  tuple_storage(U &&value) : value(static_cast<U &&>(value)) {}

  T value;
};
// Recursive implementation of tuple. tuple_inner<Idx, Ts...> stores the
// elements that start at position `Idx`. The primary template is the empty
// case (no elements left).
template <size_t Idx, typename... Elements>
class tuple_inner {};

// Last element: a single tuple_storage base.
template <size_t Idx, typename T>
class tuple_inner<Idx, T> : public tuple_storage<Idx, T> {
 public:
  tuple_inner(const T &value) : tuple_storage<Idx, T>(value) {}

  // `static_cast<U &&>` is the forwarding cast written out, avoiding a
  // dependency on std::forward.
  template <typename U>
  tuple_inner(U &&value) : tuple_storage<Idx, T>(static_cast<U &&>(value)) {}
};

// Element `Idx` plus the remaining elements, which live in the
// tuple_inner<Idx + 1, Rest...> base.
template <size_t Idx, typename T, typename... Rest>
class tuple_inner<Idx, T, Rest...> : public tuple_inner<Idx + 1, Rest...>,
                                     public tuple_storage<Idx, T> {
 public:
  tuple_inner(const T &value, const Rest &...rest)
      : tuple_inner<Idx + 1, Rest...>(rest...), tuple_storage<Idx, T>(value) {}

  // Forwards the first argument to this element and the rest down the chain.
  template <typename U, typename... Inputs>
  tuple_inner(U &&value, Inputs &&...rest)
      : tuple_inner<Idx + 1, Rest...>(static_cast<Inputs &&>(rest)...),
        tuple_storage<Idx, T>(static_cast<U &&>(value)) {}
};
// Fixed-size heterogeneous container. Elements are stored in the
// tuple_inner<0, Elements...> base and accessed with get<Idx>().
template <typename... Elements>
class tuple : public tuple_inner<0, Elements...> {
 public:
  // Copies every element from the given values.
  tuple(const Elements &...elements)
      : tuple_inner<0, Elements...>(elements...) {}

  // Constructs every element from the corresponding argument, preserving each
  // argument's value category. `static_cast<Inputs &&>` is the forwarding cast
  // written out, avoiding a dependency on std::forward.
  template <typename... Inputs>
  tuple(Inputs &&...elements)
      : tuple_inner<0, Elements...>(static_cast<Inputs &&>(elements)...) {}
};

// Explicit specialisation: tuple<void> is an empty class with no storage.
template <>
class tuple<void> {};
// Number of elements in a tuple type, exposed as the static member `size`.
// Types that are not tuples get no `size` member.
template <typename T>
struct tuple_size {};

template <typename... Elements>
struct tuple_size<tuple<Elements...>> {
  constexpr static size_t size = sizeof...(Elements);
};

// A const-qualified tuple has the same element count as the unqualified one.
// Without this, tuple_size_v<const tuple<...> &> would name the empty primary
// template, because stripping the reference leaves the const in place.
template <typename T>
struct tuple_size<const T> : tuple_size<T> {};

// Element count of `T`, ignoring any reference on it.
template <typename T>
constexpr size_t tuple_size_v = tuple_size<remove_reference_t<T>>::size;
template <typename...>
struct combine_tuples;
template <>
struct combine_tuples<> {
using type = tuple<>;
};
template <typename... T>
struct combine_tuples<tuple<T...>> {
using type = tuple<T...>;
};
template <typename... T1, typename... T2, typename... Rest>
struct combine_tuples<tuple<T1...>, tuple<T2...>, Rest...> {
using type = typename combine_tuples<tuple<T1..., T2...>, Rest...>::type;
};
template <typename... Tuples>
using tuple_cat_result =
typename combine_tuples<remove_reference_t<Tuples>...>::type;
// Yields the index sequence 0..size-1 for the FIRST tuple in a list of tuple
// types, or the empty sequence when the list is empty.
template <typename...>
struct tuple_first_index_sequence;

// Empty list of tuples.
template <>
struct tuple_first_index_sequence<> {
  typedef integer_sequence<> type;
};

// Only `Head` is inspected; the remaining tuples are ignored.
template <typename Head, typename... Tail>
struct tuple_first_index_sequence<Head, Tail...> {
  typedef make_index_sequence<tuple_size_v<Head>> type;
};

// Shorthand for the nested `type` (despite the `_v` suffix this names a type).
template <typename... Tuples>
using tuple_first_index_sequence_v =
    typename tuple_first_index_sequence<Tuples...>::type;
// Element access by index. `T` is deduced from the tuple's unique
// tuple_storage<Idx, T> base, so callers only spell `get<Idx>(t)`.
// One overload per value category / constness of the tuple.

// Mutable lvalue tuple -> mutable lvalue element.
template <size_t Idx, typename T>
T &get(tuple_storage<Idx, T> &t) {
  return t.value;
}

// Const lvalue tuple -> const lvalue element.
template <size_t Idx, typename T>
const T &get(const tuple_storage<Idx, T> &t) {
  return t.value;
}

// Rvalue tuple -> rvalue element (or the original reference when `T` is a
// reference type, through reference collapsing).
template <size_t Idx, typename T>
T &&get(tuple_storage<Idx, T> &&t) {
  return static_cast<T &&>(t.value);
}

// Const rvalue tuple -> const rvalue element. The cast targets `const T &&`:
// `t.value` is const here, so a cast to plain `T &&` would drop the const and
// fail to compile for non-reference `T`.
template <size_t Idx, typename T>
const T &&get(const tuple_storage<Idx, T> &&t) {
  return static_cast<const T &&>(t.value);
}
template <typename Callable, typename Tuple, size_t... Idx>
decltype(auto) apply_inner(Callable &&callable, Tuple &&tuple,
integer_sequence<Idx...>) {
return invoke(forward<Callable>(callable),
get<Idx>(forward<Tuple>(tuple))...);
}
template <typename Callable, typename Tuple>
decltype(auto) apply(Callable &&callable, Tuple &&tuple) {
return apply_inner(forward<Callable>(callable), forward<Tuple>(tuple),
make_index_sequence<tuple_size_v<Tuple>>{});
}
// Implementation of tuple_cat. Each step peels the first tuple off the list,
// appends its elements to the values collected so far, and recurses on the
// remaining tuples. `IdxSeq` is the index sequence of the current first tuple.
template <typename Ret, typename IdxSeq, typename... Tuples>
struct TupleConcator {};

// No tuples left: build the result from the collected values.
template <typename Ret>
struct TupleConcator<Ret, integer_sequence<>> {
  template <typename... T>
  static Ret tuple_cat_inner(T &&...values) {
    return Ret(::forward<T>(values)...);
  }
};

// At least one tuple left. `values` are the elements gathered from the tuples
// already processed; the elements of `current` are appended after them so the
// overall left-to-right order is kept.
//
// `::forward` and `::get` are qualified so that element types from namespace
// std cannot drag the std:: overloads in through argument-dependent lookup.
template <typename Ret, size_t... Idx, typename Tuple, typename... OtherTuples>
struct TupleConcator<Ret, integer_sequence<Idx...>, Tuple, OtherTuples...> {
  template <typename... T>
  static Ret tuple_cat_inner(Tuple &&current, OtherTuples &&...others,
                             T &&...values) {
    using next_idx = tuple_first_index_sequence_v<OtherTuples...>;
    using next_concator = TupleConcator<Ret, next_idx, OtherTuples...>;
    return next_concator::tuple_cat_inner(
        ::forward<OtherTuples>(others)..., ::forward<T>(values)...,
        ::get<Idx>(::forward<Tuple>(current))...);
  }
};
// Concatenates any number of tuples into one tuple holding all their elements
// in order. Elements of lvalue tuples are copied; elements of rvalue tuples
// are passed on as rvalues.
//
// `::forward` is qualified so that tuples holding types from namespace std do
// not make the call ambiguous with std::forward through ADL.
template <typename... Tuples>
auto tuple_cat(Tuples &&...tuples) -> tuple_cat_result<Tuples...> {
  using result_type = tuple_cat_result<Tuples...>;
  using leading_indices = tuple_first_index_sequence_v<Tuples...>;
  using concator = TupleConcator<result_type, leading_indices, Tuples...>;
  return concator::tuple_cat_inner(::forward<Tuples>(tuples)...);
}
// Collects the leading types of a parameter list into a tuple type. The number
// of types taken equals the length of the index sequence; the index values
// themselves are not used.
template <typename IdxSeq, typename... Args>
struct args_head {};

// Nothing left to take: the empty tuple, whatever types remain.
template <typename... Args>
struct args_head<integer_sequence<>, Args...> {
  typedef tuple<> type;
};

// Take `Front`, then recurse with one index and one type fewer.
template <typename Front, typename... Others, size_t Count, size_t... Remaining>
struct args_head<integer_sequence<Count, Remaining...>, Front, Others...> {
 private:
  using tail_type =
      typename args_head<integer_sequence<Remaining...>, Others...>::type;

 public:
  using type = tuple_cat_result<tuple<Front>, tail_type>;
};

// tuple of the first `N` types in `Args...`.
template <size_t N, typename... Args>
using args_head_t = typename args_head<make_index_sequence<N>, Args...>::type;
// Drops the leading types of a parameter list and collects what remains into a
// tuple type. The number of types dropped equals the length of the index
// sequence; the index values themselves are not used.
template <typename IdxSeq, typename... Args>
struct args_except {};

// Nothing left to drop: everything that remains forms the result.
template <typename... Args>
struct args_except<integer_sequence<>, Args...> {
  typedef tuple<Args...> type;
};

// Drop `Dropped`, then recurse with one index and one type fewer.
template <typename Dropped, typename... Kept, size_t Count, size_t... Remaining>
struct args_except<integer_sequence<Count, Remaining...>, Dropped, Kept...> {
  typedef
      typename args_except<integer_sequence<Remaining...>, Kept...>::type type;
};

// tuple of the types in `Args...` after the first `N`.
template <size_t N, typename... Args>
using args_except_t =
    typename args_except<make_index_sequence<N>, Args...>::type;
// Primary template: only the specialisation on two tuple types is usable.
template <typename Callable, typename ArgsTuple, typename UncurriedArgsTuple>
class Curried {};

// A callable with some leading arguments already bound.
//   CurriedArgs...   - types of the arguments stored in this object
//   UncurriedArgs... - types of the arguments still to be supplied
// Calling the object supplies all remaining arguments at once; curry() binds
// more arguments and returns a new Curried.
//
// Helpers are called with a leading `::` so that argument types from namespace
// std cannot make the calls ambiguous with std::move / std::forward /
// std::apply / std::tuple_cat through argument-dependent lookup.
template <typename Callable, typename... CurriedArgs, typename... UncurriedArgs>
class Curried<Callable, tuple<CurriedArgs...>, tuple<UncurriedArgs...>> {
 public:
  // Stores the callable and takes over the already-bound arguments.
  template <typename CallableT>
  Curried(CallableT &&callable, tuple<CurriedArgs...> &&args)
      : _callable(::forward<CallableT>(callable)),
        _curriedArgs(::move(args)) {}

  // Invokes the callable with the stored arguments followed by `args`.
  // The stored arguments are copied for the call, so the object can be called
  // repeatedly.
  decltype(auto) operator()(UncurriedArgs... args) const {
    auto uncurriedArgs =
        tuple<UncurriedArgs...>(::forward<UncurriedArgs>(args)...);
    return ::apply(_callable,
                   ::tuple_cat(_curriedArgs, ::move(uncurriedArgs)));
  }

  // Binds `args` as the next arguments, leaving *this unchanged.
  // `args` is an rvalue the caller gave up, so its elements are moved into the
  // new object rather than copied.
  template <typename... Args>
  auto curry(tuple<Args...> &&args) const & {
    using NewCurried =
        Curried<Callable,
                tuple_cat_result<tuple<CurriedArgs...>, tuple<Args...>>,
                args_except_t<sizeof...(Args), UncurriedArgs...>>;
    return NewCurried(_callable, ::tuple_cat(_curriedArgs, ::move(args)));
  }

  // Same, for an expiring Curried: the callable and the stored arguments are
  // moved into the new object as well.
  template <typename... Args>
  auto curry(tuple<Args...> &&args) && {
    using NewCurried =
        Curried<Callable,
                tuple_cat_result<tuple<CurriedArgs...>, tuple<Args...>>,
                args_except_t<sizeof...(Args), UncurriedArgs...>>;
    return NewCurried(::move(_callable),
                      ::tuple_cat(::move(_curriedArgs), ::move(args)));
  }

 private:
  Callable _callable;
  // mutable so the const member functions can pass it on as a non-const tuple.
  mutable tuple<CurriedArgs...> _curriedArgs;
};
// Primary template: only the specialisation on two tuple types is usable.
template <typename CurriedArgsTuple, typename UncurriedArgsTuple>
struct Curry {};

// Builds a Curried object once the parameter list has been split into the
// part bound now (CurriedArgs...) and the part supplied later
// (UncurriedArgs...). `Ret` is deduced from the Function argument.
//
// `::forward` / `::move` are qualified so that argument types from namespace
// std cannot make the calls ambiguous with the std:: versions through ADL.
template <typename... CurriedArgs, typename... UncurriedArgs>
struct Curry<tuple<CurriedArgs...>, tuple<UncurriedArgs...>> {
  // Lvalue Function: the callable is copied into the result.
  template <typename Ret, typename... Args>
  static auto curry_inner(
      const Function<Ret(CurriedArgs..., UncurriedArgs...)> &callable,
      Args &&...args) {
    using Callable = Function<Ret(CurriedArgs..., UncurriedArgs...)>;
    using CurriedT =
        Curried<Callable, tuple<CurriedArgs...>, tuple<UncurriedArgs...>>;
    return CurriedT(callable, tuple<CurriedArgs...>(::forward<Args>(args)...));
  }

  // Rvalue Function: the callable is moved into the result.
  template <typename Ret, typename... Args>
  static auto curry_inner(
      Function<Ret(CurriedArgs..., UncurriedArgs...)> &&callable,
      Args &&...args) {
    using Callable = Function<Ret(CurriedArgs..., UncurriedArgs...)>;
    using CurriedT =
        Curried<Callable, tuple<CurriedArgs...>, tuple<UncurriedArgs...>>;
    return CurriedT(::move(callable),
                    tuple<CurriedArgs...>(::forward<Args>(args)...));
  }
};
template <typename Ret, typename... FullArgs, typename... Args>
auto curry(Function<Ret(FullArgs...)> &&callable, Args &&...args) {
using CurriedArgsTuple = args_head_t<sizeof...(Args), FullArgs...>;
using UncurriedArgsTuple = args_except_t<sizeof...(Args), FullArgs...>;
using CurryWrapper = Curry<CurriedArgsTuple, UncurriedArgsTuple>;
return CurryWrapper::curry_inner(move(callable), forward<Args>(args)...);
}
template <typename Ret, typename... FullArgs, typename... Args>
auto curry(const Function<Ret(FullArgs...)> &callable, Args &&...args) {
using CurriedArgsTuple = args_head_t<sizeof...(Args), FullArgs...>;
using UncurriedArgsTuple = args_except_t<sizeof...(Args), FullArgs...>;
using CurryWrapper = Curry<CurriedArgsTuple, UncurriedArgsTuple>;
return CurryWrapper::curry_inner(callable, forward<Args>(args)...);
}
template <typename Ret, typename... FullArgs, typename... Args>
auto curry(Ret (*func)(FullArgs...), Args &&...args) {
Function<Ret(FullArgs...)> callable = func;
return curry(move(callable), forward<Args>(args)...);
}
template <typename Ret, typename T, typename... FullArgs>
auto curry(Ret (T::*func)(FullArgs...)) {
Function<Ret(T *, FullArgs...)> callable = func;
return curry(move(callable));
}
template <typename Ret, typename T, typename... FullArgs, typename... Args>
auto curry(Ret (T::*func)(FullArgs...), T *caller, Args &&...args) {
Function<Ret(T *, FullArgs...)> callable = func;
return curry(move(callable), caller, forward<Args>(args)...);
}
template <typename... CurriedArgsList, typename... Args>
auto curry(const Curried<CurriedArgsList...> &curried, Args &&...args) {
return curried.curry(tuple<Args...>(forward<Args>(args)...));
}
template <typename... CurriedArgsList, typename... Args>
auto curry(Curried<CurriedArgsList...> &&curried, Args &&...args) {
return curried.curry(tuple<Args...>(forward<Args>(args)...));
}
// Demo type with one non-const member function, used to show currying of
// pointers-to-member-functions.
struct Test {
  int a;

  // Writes `a` followed by a newline to standard output and flushes it.
  void print() {
    std::cout << a;
    std::cout << std::endl;
  }
};
// Demo function: writes its four arguments to standard output separated by
// single spaces, then a newline, and flushes the stream.
void print(int a, double b, char c, const char *d) {
  const char *const gap = " ";
  std::cout << a << gap << b << gap << c << gap << d;
  std::cout << std::endl;
}
// Demo function with a reference parameter: adds `y` to `x` in place.
void add(int &x, int y) {
  x = x + y;
}
// Demo function with a reference return type: hands back the very object it
// was given.
int &retref(int &x) {
  int &same = x;
  return same;
}
int main() {
auto curriedNoArg = curry(print);
curriedNoArg(1919, 114.514, 'a', "hell word");
auto curriedFullArgs = curry(print, 1919, 114.514, 'a', "hell word");
curriedFullArgs();
auto curriedOneArg = curry(print, 1919);
curriedOneArg(114.514, 'a', "hell word");
auto curriedOfCurried = curry(curriedOneArg, 114.514);
curriedOfCurried('a', "hell word");
Test t = {810};
auto memberFuncCurriedNoArg = curry(&Test::print);
memberFuncCurriedNoArg(&t);
auto memberFuncCurriedSelfArg = curry(&Test::print, &t);
memberFuncCurriedSelfArg();
int x = 1;
auto curriedWithRef = curry(add, x);
curriedWithRef(2);
std::cout << x << std::endl;
auto curriedRetRef = curry(retref, x);
curriedRetRef() = 5;
std::cout << x << std::endl;
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment