Last active
May 19, 2023 03:08
-
-
Save glennvl/c03845c3c6499af6e9aa8308e56f2247 to your computer and use it in GitHub Desktop.
Function wrapper that uses static memory
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*! | |
* \file | |
* \brief Owning function wrapper that uses static memory | |
*/ | |
#ifndef STATIC_FUNCTION_HPP | |
#define STATIC_FUNCTION_HPP | |
#include <array>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <new>
#include <type_traits>
#include <utility>

#include <gsl/assert>
namespace util { | |
template <typename TFunction, std::size_t StorageSize = sizeof(std::uintptr_t) * 3u> | |
class StaticFunction; | |
template <std::size_t StorageSize, typename TRet, typename... TArgs> | |
class StaticFunction<TRet(TArgs...), StorageSize> | |
{ | |
public: | |
StaticFunction() = default; | |
explicit StaticFunction(std::nullptr_t) {} | |
template <typename TFunction> | |
requires std::is_invocable_r_v<TRet, TFunction, TArgs...> | |
and (not std::is_same_v<std::decay_t<TFunction>, StaticFunction>) | |
explicit StaticFunction(TFunction &&fn) : valid_{true} | |
{ | |
assign(std::forward<TFunction>(fn)); | |
} | |
~StaticFunction() { destroy(); } | |
StaticFunction(StaticFunction &&other) noexcept : valid_{other.valid_} | |
{ | |
if (valid_) | |
{ | |
other.get().move_to(storage()); | |
other.valid_ = false; | |
} | |
} | |
StaticFunction(StaticFunction const &other) : valid_{other.valid_} | |
{ | |
if (valid_) { other.get().copy_to(storage()); } | |
} | |
StaticFunction &operator=(StaticFunction &&rhs) noexcept | |
{ | |
auto copy{std::move(rhs)}; | |
swap(*this, copy); | |
return *this; | |
} | |
StaticFunction &operator=(StaticFunction const &rhs) | |
{ | |
auto copy{rhs}; | |
swap(*this, copy); | |
return *this; | |
} | |
template <typename TFunction> | |
requires std::is_invocable_r_v<TRet, TFunction, TArgs...> | |
and (not std::is_same_v<std::decay_t<TFunction>, StaticFunction>) | |
StaticFunction &operator=(TFunction &&fn) | |
{ | |
destroy(); | |
assign(std::forward<TFunction>(fn)); | |
return *this; | |
} | |
template <typename TFunction> | |
StaticFunction &operator=(std::reference_wrapper<TFunction> fn) | |
{ | |
destroy(); | |
assign(std::forward<TFunction>(fn)); | |
return *this; | |
} | |
StaticFunction &operator=(std::nullptr_t) | |
{ | |
destroy(); | |
return *this; | |
} | |
[[nodiscard]] bool valid() const { return valid_; } | |
[[nodiscard]] operator bool() const { return valid(); } | |
[[nodiscard]] bool operator!() const { return !valid(); } | |
TRet operator()(TArgs... args) const | |
{ | |
Expects(valid()); | |
return get().invoke(std::forward<TArgs>(args)...); | |
} | |
TRet operator()(TArgs... args) | |
{ | |
Expects(valid()); | |
return get().invoke(std::forward<TArgs>(args)...); | |
} | |
private: | |
struct FunctionConcept | |
{ | |
virtual ~FunctionConcept() = default; | |
virtual void move_to(FunctionConcept *dest) noexcept = 0; | |
virtual void copy_to(FunctionConcept *dest) const = 0; | |
virtual TRet invoke(TArgs... args) const = 0; | |
virtual TRet invoke(TArgs... args) = 0; | |
}; | |
template <typename TFunction> | |
class FunctionModel final : public FunctionConcept | |
{ | |
public: | |
template <typename TFunctionWithQualifiers> | |
requires std::same_as<TFunction, std::decay_t<TFunctionWithQualifiers>> | |
explicit FunctionModel(TFunctionWithQualifiers &&fn) : fn_{std::forward<TFunctionWithQualifiers>(fn)} | |
{ | |
} | |
FunctionModel(FunctionModel &&) noexcept = default; | |
FunctionModel(FunctionModel const &) = default; | |
FunctionModel &operator=(FunctionModel &&) noexcept = default; | |
FunctionModel &operator=(FunctionModel const &) = default; | |
~FunctionModel() = default; | |
void move_to(FunctionConcept *dest) noexcept override | |
{ | |
std::construct_at(reinterpret_cast<FunctionModel *>(dest), std::move(*this)); | |
} | |
void copy_to(FunctionConcept *dest) const override | |
{ | |
std::construct_at(reinterpret_cast<FunctionModel *>(dest), *this); | |
} | |
TRet invoke(TArgs... args) const override { return fn_(std::forward<TArgs>(args)...); } | |
TRet invoke(TArgs... args) override { return fn_(std::forward<TArgs>(args)...); } | |
private: | |
TFunction fn_; | |
}; | |
friend void swap(StaticFunction &lhs, StaticFunction &rhs) noexcept | |
{ | |
std::swap(lhs.storage_, rhs.storage_); | |
std::swap(lhs.valid_, rhs.valid_); | |
} | |
[[nodiscard]] friend bool operator==(StaticFunction const &fn, std::nullptr_t) { return !fn; } | |
template <typename TFunction> | |
void assign(TFunction &&fn) | |
{ | |
using ThisType = StaticFunction<TRet(TArgs...), StorageSize>; | |
using FuncType = std::decay_t<TFunction>; | |
using ModelType = FunctionModel<FuncType>; | |
static_assert(not std::is_same_v<ThisType, FuncType>, "Cannot store StaticFunction inside itself"); | |
static_assert(sizeof(ModelType) <= storage_size_, "Function does not fit, increase storage size"); | |
static_assert(alignof(FunctionConcept) == alignof(ModelType), "Cannot accommodate alignment requirements"); | |
std::construct_at(reinterpret_cast<ModelType *>(storage()), std::forward<TFunction>(fn)); | |
valid_ = true; | |
} | |
void destroy() | |
{ | |
if (valid_) | |
{ | |
std::destroy_at(&get()); | |
valid_ = false; | |
} | |
} | |
[[nodiscard]] FunctionConcept const &get() const { return *std::launder(storage()); } | |
[[nodiscard]] FunctionConcept &get() { return *std::launder(storage()); } | |
[[nodiscard]] FunctionConcept const *storage() const | |
{ | |
return reinterpret_cast<FunctionConcept const *>(storage_.data()); | |
} | |
[[nodiscard]] FunctionConcept *storage() { return reinterpret_cast<FunctionConcept *>(storage_.data()); } | |
static constexpr auto storage_size_{StorageSize + sizeof(FunctionConcept)}; | |
alignas(FunctionConcept) std::array<std::byte, storage_size_> storage_{}; | |
bool valid_{false}; | |
}; | |
} // namespace util | |
#endif // STATIC_FUNCTION_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstdlib> | |
#include <exception> | |
#include <functional> | |
#include <utility> | |
#include <catch2/catch_test_macros.hpp> | |
#include <catch_ext/death_test.hpp> | |
#include <util/static_function.hpp> | |
using util::StaticFunction; | |
namespace {

// Plain free function used as a function-pointer target: returns 5 + some_y.
int function(int some_y)
{
    constexpr int offset = 5;
    return offset + some_y;
}

// Function object whose const call operator adds a constructor-supplied offset.
class Functor
{
public:
    explicit Functor(int some_x) : some_x_{some_x} {}

    int operator()(int some_y) const
    {
        int const sum = some_x_ + some_y;
        return sum;
    }

private:
    int some_x_;
};

// Capturing lambda that behaves like Functor{5}.
auto const lambda = [some_x = 5](int some_y) {
    int const sum = some_x + some_y;
    return sum;
};

} // anonymous namespace
// Every kind of callable the tests use (function pointer, std::bind result, function
// object, lambda) can be stored in a StaticFunction and invoked.
TEST_CASE("can store all sorts of functions", "[util][StaticFunction]")
{
    SECTION("works with regular function")
    {
        StaticFunction<int(int)> func{&function};
        REQUIRE(func(1) == 6);
        // std::as_const selects the const-qualified call operator.
        REQUIRE(std::as_const(func)(1) == 6);
    }
    SECTION("works with member function")
    {
        // The bind object holds a pointer to `functor`, so `functor` must outlive `func`.
        Functor functor{5};
        StaticFunction<int(int)> func{
            std::bind(&Functor::operator(), &functor, std::placeholders::_1)}; // NOLINT(modernize-avoid-bind)
        REQUIRE(func(1) == 6);
    }
    SECTION("works with functor")
    {
        StaticFunction<int(int)> func{Functor{5}};
        REQUIRE(func(1) == 6);
    }
    SECTION("works with lambda")
    {
        StaticFunction<int(int)> func{lambda};
        REQUIRE(func(1) == 6);
    }
    SECTION("works with multiple arguments")
    {
        // Covers zero-, two- and one-argument signatures, a void return type, and a
        // lambda that mutates state captured by reference.
        int some_x = 5;
        StaticFunction<void()> func_1{[] {
        }};
        StaticFunction<int(int, int)> func_2{[](int some_a, int some_b) {
            return some_a + some_b;
        }};
        StaticFunction<void(int)> func_3{[&some_x](int some_y) {
            some_x += some_y;
        }};
        REQUIRE(!!func_1);
        // The comma operator lets a void call appear inside REQUIRE.
        REQUIRE((func_1(), true));
        REQUIRE(!!func_2);
        REQUIRE(func_2(1, 2) == 3);
        REQUIRE(!!func_3);
        REQUIRE((func_3(1), some_x) == 6);
    }
}
// Emptiness can be queried via operator!, comparison with nullptr and the bool conversion.
TEST_CASE("can be queried for empty", "[util][StaticFunction]")
{
    SECTION("is empty when default constructed")
    {
        StaticFunction<int(int)> func;
        REQUIRE(!func);
        REQUIRE(func == nullptr);
        REQUIRE(!static_cast<bool>(func));
    }
    SECTION("is empty when constructed from null pointer")
    {
        StaticFunction<int(int)> func(nullptr);
        REQUIRE(!func);
        REQUIRE(func == nullptr);
        REQUIRE(!static_cast<bool>(func));
    }
    SECTION("is not empty when function is stored")
    {
        StaticFunction<int(int)> func{&function};
        REQUIRE(!!func);
        // Uses the != operator that C++20 rewrites from operator==(StaticFunction, nullptr_t).
        REQUIRE(func != nullptr);
        REQUIRE(static_cast<bool>(func));
    }
}
// Copying keeps the source usable; moving transfers the callable and leaves the source empty.
TEST_CASE("copy and move construction and assignment", "[util][StaticFunction]")
{
    SECTION("move construction")
    {
        StaticFunction<int(int)> func_1{&function};
        StaticFunction<int(int)> func_2{std::move(func_1)};
        REQUIRE(func_2);
        REQUIRE(func_2(1) == 6);
        // StaticFunction documents that a moved-from wrapper is empty.
        REQUIRE(!func_1); // NOLINT(bugprone-use-after-move)
    }
    SECTION("copy construction")
    {
        StaticFunction<int(int)> func_1{&function};
        StaticFunction<int(int)> func_2{func_1};
        REQUIRE(!!func_1);
        REQUIRE(func_1(1) == 6);
        REQUIRE(!!func_2);
        REQUIRE(func_2(1) == 6);
    }
    SECTION("move assignment")
    {
        StaticFunction<int(int)> func_1{&function};
        StaticFunction<int(int)> func_2;
        func_2 = std::move(func_1);
        REQUIRE(func_2);
        REQUIRE(func_2(1) == 6);
        REQUIRE(!func_1); // NOLINT(bugprone-use-after-move)
    }
    SECTION("copy assignment")
    {
        StaticFunction<int(int)> func_1{&function};
        StaticFunction<int(int)> func_2;
        func_2 = func_1;
        REQUIRE(!!func_1);
        REQUIRE(func_1(1) == 6);
        REQUIRE(!!func_2);
        REQUIRE(func_2(1) == 6);
    }
    SECTION("assign after move")
    {
        // A moved-from wrapper can be reassigned and used again.
        StaticFunction<int(int)> func_1{&function};
        StaticFunction<int(int)> func_2;
        func_2 = std::move(func_1);
        func_1 = func_2;
        REQUIRE(!!func_1);
        REQUIRE(func_1(1) == 6);
    }
}
// Every supported kind of callable (and nullptr) can be assigned to an existing wrapper.
TEST_CASE("assign a new function", "[util][StaticFunction]")
{
    SECTION("function is cleared when assigning null pointer")
    {
        StaticFunction<int(int)> func{&function};
        func = nullptr;
        REQUIRE(!func);
    }
    SECTION("regular function can be assigned")
    {
        StaticFunction<int(int)> func;
        func = &function;
        REQUIRE(!!func);
        REQUIRE(func(1) == 6);
    }
    SECTION("member function can be assigned")
    {
        // The bind object holds a pointer to `functor`, so `functor` must outlive `func`.
        StaticFunction<int(int)> func;
        Functor functor{5};
        func = std::bind(&Functor::operator(), &functor, std::placeholders::_1); // NOLINT(modernize-avoid-bind)
        REQUIRE(!!func);
        REQUIRE(func(1) == 6);
    }
    SECTION("functor can be assigned")
    {
        StaticFunction<int(int)> func;
        func = Functor{5};
        REQUIRE(!!func);
        REQUIRE(func(1) == 6);
    }
    SECTION("lambda can be assigned")
    {
        StaticFunction<int(int)> func;
        func = lambda;
        REQUIRE(!!func);
        REQUIRE(func(1) == 6);
    }
    SECTION("reference wrapped function can be assigned")
    {
        // Exercises StaticFunction's std::reference_wrapper assignment overload.
        auto functor = Functor{5};
        std::reference_wrapper const wrapped{functor};
        StaticFunction<int(int)> func;
        func = wrapped;
        REQUIRE(!!func);
        REQUIRE(func(1) == 6);
    }
}
// Calling an empty wrapper (const or non-const) violates the Expects(valid()) precondition.
TEST_CASE("invoke function", "[util][StaticFunction]")
{
    // NOTE(review): presumably makes a failed gsl Expects (assumed to call std::terminate)
    // end the process via std::abort so that REQUIRE_DEATH can observe it; confirm
    // against the gsl and catch_ext configuration.
    std::set_terminate(std::abort);
    SECTION("invoking empty function violates precondition")
    {
        REQUIRE_DEATH([] {
            StaticFunction<void()> func{};
            func();
        });
        REQUIRE_DEATH([] {
            StaticFunction<void()> const func{};
            func();
        });
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Tested on Compiler Explorer with:
-std=c++20 -O3 -Wall -Wextra -Wpedantic -Wshadow -Wconversion -Wsign-conversion -Wdouble-promotion -Werror -lCatch2Main -fsanitize=address,undefined
Catch2 3.0.0-preview3
std::construct_at
+ make `operator==` a hidden friend