Skip to content

Instantly share code, notes, and snippets.

@glennvl
Last active May 19, 2023 03:08
Show Gist options
  • Save glennvl/c03845c3c6499af6e9aa8308e56f2247 to your computer and use it in GitHub Desktop.
Save glennvl/c03845c3c6499af6e9aa8308e56f2247 to your computer and use it in GitHub Desktop.
Function wrapper that uses static memory
/*!
* \file
* \brief Owning function wrapper that uses static memory
*/
#ifndef STATIC_FUNCTION_HPP
#define STATIC_FUNCTION_HPP
#include <array>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <new>
#include <type_traits>
#include <utility>
#include <gsl/assert>
namespace util {
template <typename TFunction, std::size_t StorageSize = sizeof(std::uintptr_t) * 3u>
class StaticFunction;

/*!
 * \brief Owning, type-erased function wrapper whose target lives inside the
 *        wrapper object itself (no heap allocation).
 *
 * \tparam StorageSize Bytes reserved for the stored callable's own state; the
 *                     internal model's vtable pointer is accounted for on top.
 *                     Callables that do not fit are rejected at compile time.
 * \tparam TRet        Return type of the call signature. When it is void, the
 *                     result of the stored callable (if any) is discarded.
 * \tparam TArgs       Parameter types of the call signature.
 *
 * The stored callable is invoked through std::invoke, so pointers to members
 * work as well as function pointers and function objects. It must be callable
 * through a const reference, because operator() exists in const and non-const
 * flavours and both are compiled for every stored type.
 *
 * Invariant: callable_ is nullptr when the wrapper is empty; otherwise it
 * points to the FunctionModel object that was constructed inside storage_.
 *
 * Moves are noexcept and leave the source empty. The internal model's move
 * constructor is declared noexcept, so a stored callable whose move
 * constructor throws terminates the program.
 */
template <std::size_t StorageSize, typename TRet, typename... TArgs>
class StaticFunction<TRet(TArgs...), StorageSize>
{
public:
    //! Constructs an empty wrapper.
    StaticFunction() = default;

    //! Constructs an empty wrapper (counterpart of std::function's nullptr constructor).
    explicit StaticFunction(std::nullptr_t) {}

    //! Stores a decayed copy of \p fn (moved from \p fn when it is an rvalue).
    template <typename TFunction>
    requires std::is_invocable_r_v<TRet, TFunction, TArgs...>
             and (not std::is_same_v<std::decay_t<TFunction>, StaticFunction>)
    explicit StaticFunction(TFunction &&fn)
    {
        assign(std::forward<TFunction>(fn));
    }

    ~StaticFunction() { destroy(); }

    //! Moves the callable held by \p other (if any) into this wrapper and
    //! destroys the moved-from callable, leaving \p other empty.
    StaticFunction(StaticFunction &&other) noexcept { take_from(other); }

    //! Copies the callable held by \p other (if any) into this wrapper.
    StaticFunction(StaticFunction const &other)
    {
        if (other.valid()) { callable_ = other.get().copy_to(storage_.data()); }
    }

    //! Replaces the stored callable with the one held by \p rhs; \p rhs is left empty.
    StaticFunction &operator=(StaticFunction &&rhs) noexcept
    {
        // On self-move, destroying our callable would also destroy rhs's.
        if (this != &rhs)
        {
            destroy();
            take_from(rhs);
        }
        return *this;
    }

    //! Copy assignment with the strong exception guarantee: the copy is made
    //! before the current callable is released.
    StaticFunction &operator=(StaticFunction const &rhs)
    {
        if (this != &rhs)
        {
            StaticFunction copy{rhs};
            *this = std::move(copy);
        }
        return *this;
    }

    //! Replaces the stored callable with a decayed copy of \p fn.
    //! If copying/moving \p fn throws, the wrapper is left empty.
    template <typename TFunction>
    requires std::is_invocable_r_v<TRet, TFunction, TArgs...>
             and (not std::is_same_v<std::decay_t<TFunction>, StaticFunction>)
    StaticFunction &operator=(TFunction &&fn)
    {
        destroy();
        assign(std::forward<TFunction>(fn));
        return *this;
    }

    //! Stores the reference_wrapper itself, so calls are forwarded to the
    //! referenced object, which must outlive this wrapper. Previously the
    //! referenced object was forwarded as an rvalue, i.e. moved into a copy.
    template <typename TFunction>
    requires std::is_invocable_r_v<TRet, TFunction &, TArgs...>
    StaticFunction &operator=(std::reference_wrapper<TFunction> fn)
    {
        destroy();
        assign(fn);
        return *this;
    }

    //! Empties the wrapper.
    StaticFunction &operator=(std::nullptr_t)
    {
        destroy();
        return *this;
    }

    //! True when a callable is stored.
    [[nodiscard]] bool valid() const { return callable_ != nullptr; }
    // NOTE(review): this conversion is implicit, so e.g. `int n = func;` compiles;
    // consider making it explicit once no caller relies on the implicit form.
    [[nodiscard]] operator bool() const { return valid(); }
    [[nodiscard]] bool operator!() const { return !valid(); }

    //! Invokes the stored callable through a const reference.
    //! \pre valid() (checked with Expects).
    TRet operator()(TArgs... args) const
    {
        Expects(valid());
        return get().invoke(std::forward<TArgs>(args)...);
    }

    //! Invokes the stored callable through a non-const reference.
    //! \pre valid() (checked with Expects).
    TRet operator()(TArgs... args)
    {
        Expects(valid());
        return get().invoke(std::forward<TArgs>(args)...);
    }

private:
    //! Type-erased interface implemented by FunctionModel for each stored callable type.
    struct FunctionConcept
    {
        virtual ~FunctionConcept() = default;
        //! Move-constructs this model into raw storage at \p dest and returns the new object.
        virtual FunctionConcept *move_to(void *dest) noexcept = 0;
        //! Copy-constructs this model into raw storage at \p dest and returns the new object.
        virtual FunctionConcept *copy_to(void *dest) const = 0;
        virtual TRet invoke(TArgs... args) const = 0;
        virtual TRet invoke(TArgs... args) = 0;
    };

    //! Concrete model owning a callable of type TFunction.
    template <typename TFunction>
    class FunctionModel final : public FunctionConcept
    {
    public:
        template <typename TFunctionWithQualifiers>
        requires std::same_as<TFunction, std::decay_t<TFunctionWithQualifiers>>
        explicit FunctionModel(TFunctionWithQualifiers &&fn) : fn_{std::forward<TFunctionWithQualifiers>(fn)}
        {
        }
        FunctionModel(FunctionModel &&) noexcept = default;
        FunctionModel(FunctionModel const &) = default;
        FunctionModel &operator=(FunctionModel &&) noexcept = default;
        FunctionModel &operator=(FunctionModel const &) = default;
        ~FunctionModel() override = default;

        FunctionConcept *move_to(void *dest) noexcept override
        {
            return std::construct_at(static_cast<FunctionModel *>(dest), std::move(*this));
        }
        FunctionConcept *copy_to(void *dest) const override
        {
            return std::construct_at(static_cast<FunctionModel *>(dest), *this);
        }
        TRet invoke(TArgs... args) const override { return call(fn_, std::forward<TArgs>(args)...); }
        TRet invoke(TArgs... args) override { return call(fn_, std::forward<TArgs>(args)...); }

    private:
        //! Calls \p target through std::invoke (so member pointers work, matching the
        //! is_invocable_r constraint). For a void signature the callable's result is
        //! explicitly discarded; a plain `return fn_(...)` would not compile there.
        template <typename TTarget, typename... TCallArgs>
        static TRet call(TTarget &target, TCallArgs &&...call_args)
        {
            if constexpr (std::is_void_v<TRet>)
            {
                static_cast<void>(std::invoke(target, std::forward<TCallArgs>(call_args)...));
            }
            else
            {
                return std::invoke(target, std::forward<TCallArgs>(call_args)...);
            }
        }

        TFunction fn_;
    };

    //! Exchanges the stored callables of \p lhs and \p rhs.
    friend void swap(StaticFunction &lhs, StaticFunction &rhs) noexcept
    {
        // Go through the models' move constructors instead of swapping raw bytes:
        // a byte-wise copy of a non-trivially-relocatable callable (e.g. one that
        // holds a pointer into itself) does not produce a valid object.
        if (&lhs == &rhs) { return; }
        StaticFunction tmp{std::move(lhs)};
        lhs = std::move(rhs);
        rhs = std::move(tmp);
    }

    [[nodiscard]] friend bool operator==(StaticFunction const &fn, std::nullptr_t) { return !fn; }

    //! Constructs a model for \p fn inside storage_. Precondition: the wrapper is empty.
    template <typename TFunction>
    void assign(TFunction &&fn)
    {
        using ThisType = StaticFunction<TRet(TArgs...), StorageSize>;
        using FuncType = std::decay_t<TFunction>;
        using ModelType = FunctionModel<FuncType>;
        static_assert(not std::is_same_v<ThisType, FuncType>, "Cannot store StaticFunction inside itself");
        static_assert(std::is_invocable_r_v<TRet, FuncType const &, TArgs...>,
                      "Function must be invocable through a const reference");
        static_assert(sizeof(ModelType) <= storage_size_, "Function does not fit, increase storage size");
        static_assert(alignof(ModelType) <= alignof(FunctionConcept), "Cannot accommodate alignment requirements");
        // Keep the pointer returned by construct_at: it addresses the FunctionConcept
        // base subobject directly, without assuming that base sits at offset 0.
        callable_ = std::construct_at(static_cast<ModelType *>(static_cast<void *>(storage_.data())),
                                      std::forward<TFunction>(fn));
    }

    //! Destroys the stored callable, if any, leaving the wrapper empty.
    void destroy() noexcept
    {
        if (callable_ != nullptr)
        {
            std::destroy_at(callable_);
            callable_ = nullptr;
        }
    }

    //! Moves \p other's callable into this wrapper (which must be empty), then
    //! destroys the moved-from callable so its destructor runs exactly once.
    void take_from(StaticFunction &other) noexcept
    {
        if (other.callable_ != nullptr)
        {
            callable_ = other.callable_->move_to(storage_.data());
            other.destroy();
        }
    }

    [[nodiscard]] FunctionConcept const &get() const { return *callable_; }
    [[nodiscard]] FunctionConcept &get() { return *callable_; }

    static constexpr auto storage_size_{StorageSize + sizeof(FunctionConcept)};
    alignas(FunctionConcept) std::array<std::byte, storage_size_> storage_{};
    //! Points at the model living in storage_, or nullptr when empty.
    FunctionConcept *callable_{nullptr};
};
} // namespace util
#endif // STATIC_FUNCTION_HPP
#include <cstdlib>
#include <exception>
#include <functional>
#include <utility>
#include <catch2/catch_test_macros.hpp>
#include <catch_ext/death_test.hpp>
#include <util/static_function.hpp>
using util::StaticFunction;
namespace
{
//! Plain free function used as a function-pointer target: yields its argument plus 5.
int function(int value)
{
    return value + 5;
}

//! Function object that adds the offset given at construction to its argument.
class Functor
{
public:
    explicit Functor(int offset) : offset_{offset} {}

    int operator()(int value) const
    {
        return offset_ + value;
    }

private:
    int offset_;
};

//! Closure with captured state (an offset of 5) used as a lambda target.
auto const lambda = [offset = 5](int value) { return offset + value; };
} // anonymous namespace
// Checks that StaticFunction<Sig> accepts each kind of callable the tests use
// (function pointer, std::bind result, functor, lambda) and forwards the call.
TEST_CASE("can store all sorts of functions", "[util][StaticFunction]")
{
SECTION("works with regular function")
{
StaticFunction<int(int)> func{&function};
REQUIRE(func(1) == 6);
// Also exercises the const-qualified call operator.
REQUIRE(std::as_const(func)(1) == 6);
}
SECTION("works with member function")
{
Functor functor{5};
// std::bind stores a copy of the pointer &functor, so `functor` must outlive `func`.
StaticFunction<int(int)> func{
std::bind(&Functor::operator(), &functor, std::placeholders::_1)}; // NOLINT(modernize-avoid-bind)
REQUIRE(func(1) == 6);
}
SECTION("works with functor")
{
StaticFunction<int(int)> func{Functor{5}};
REQUIRE(func(1) == 6);
}
SECTION("works with lambda")
{
StaticFunction<int(int)> func{lambda};
REQUIRE(func(1) == 6);
}
// Covers zero, two and one-argument signatures, including void return types.
SECTION("works with multiple arguments")
{
int some_x = 5;
StaticFunction<void()> func_1{[] {
}};
StaticFunction<int(int, int)> func_2{[](int some_a, int some_b) {
return some_a + some_b;
}};
// Captures some_x by reference so the side effect of the call is observable.
StaticFunction<void(int)> func_3{[&some_x](int some_y) {
some_x += some_y;
}};
REQUIRE(!!func_1);
// Comma expressions let REQUIRE evaluate calls whose result type is void.
REQUIRE((func_1(), true));
REQUIRE(!!func_2);
REQUIRE(func_2(1, 2) == 3);
REQUIRE(!!func_3);
REQUIRE((func_3(1), some_x) == 6);
}
}
// Each section checks the emptiness queries: operator!, comparison against
// nullptr, and the conversion to bool.
TEST_CASE("can be queried for empty", "[util][StaticFunction]")
{
SECTION("is empty when default constructed")
{
StaticFunction<int(int)> func;
REQUIRE(!func);
REQUIRE(func == nullptr);
REQUIRE(!static_cast<bool>(func));
}
SECTION("is empty when constructed from null pointer")
{
StaticFunction<int(int)> func(nullptr);
REQUIRE(!func);
REQUIRE(func == nullptr);
REQUIRE(!static_cast<bool>(func));
}
SECTION("is not empty when function is stored")
{
StaticFunction<int(int)> func{&function};
REQUIRE(!!func);
// There is no dedicated operator!=; under C++20 this is rewritten in terms of
// operator==(StaticFunction const &, std::nullptr_t).
REQUIRE(func != nullptr);
REQUIRE(static_cast<bool>(func));
}
}
// Copy/move construction and assignment, the moved-from state, and swap.
TEST_CASE("copy and move construction and assignment", "[util][StaticFunction]")
{
    SECTION("move construction")
    {
        StaticFunction<int(int)> func_1{&function};
        StaticFunction<int(int)> func_2{std::move(func_1)};
        REQUIRE(func_2);
        REQUIRE(func_2(1) == 6);
    }
    SECTION("copy construction")
    {
        StaticFunction<int(int)> func_1{&function};
        StaticFunction<int(int)> func_2{func_1};
        REQUIRE(!!func_1);
        REQUIRE(func_1(1) == 6);
        REQUIRE(!!func_2);
        REQUIRE(func_2(1) == 6);
    }
    SECTION("move assignment")
    {
        StaticFunction<int(int)> func_1{&function};
        StaticFunction<int(int)> func_2;
        func_2 = std::move(func_1);
        REQUIRE(func_2);
        REQUIRE(func_2(1) == 6);
    }
    SECTION("copy assignment")
    {
        StaticFunction<int(int)> func_1{&function};
        StaticFunction<int(int)> func_2;
        func_2 = func_1;
        REQUIRE(!!func_1);
        REQUIRE(func_1(1) == 6);
        REQUIRE(!!func_2);
        REQUIRE(func_2(1) == 6);
    }
    SECTION("assign after move")
    {
        StaticFunction<int(int)> func_1{&function};
        StaticFunction<int(int)> func_2;
        func_2 = std::move(func_1);
        func_1 = func_2;
        REQUIRE(!!func_1);
        REQUIRE(func_1(1) == 6);
    }
    // Moving out of a wrapper leaves it empty, for both construction and assignment.
    SECTION("moved-from function is empty")
    {
        StaticFunction<int(int)> func_1{&function};
        StaticFunction<int(int)> func_2{std::move(func_1)};
        REQUIRE(!func_1); // NOLINT(bugprone-use-after-move)
        StaticFunction<int(int)> func_3;
        func_3 = std::move(func_2);
        REQUIRE(!func_2); // NOLINT(bugprone-use-after-move)
        REQUIRE(func_3(1) == 6);
    }
    // swap is a hidden friend of StaticFunction, found here through argument-dependent lookup.
    SECTION("swap exchanges stored functions")
    {
        StaticFunction<int(int)> func_1{&function};
        StaticFunction<int(int)> func_2{Functor{10}};
        swap(func_1, func_2);
        REQUIRE(func_1(1) == 11);
        REQUIRE(func_2(1) == 6);
    }
    SECTION("swap with an empty function")
    {
        StaticFunction<int(int)> func_1{&function};
        StaticFunction<int(int)> func_2;
        swap(func_1, func_2);
        REQUIRE(!func_1);
        REQUIRE(!!func_2);
        REQUIRE(func_2(1) == 6);
    }
}
// Assignment from nullptr and from each supported kind of callable.
TEST_CASE("assign a new function", "[util][StaticFunction]")
{
SECTION("function is cleared when assigning null pointer")
{
StaticFunction<int(int)> func{&function};
func = nullptr;
REQUIRE(!func);
}
SECTION("regular function can be assigned")
{
StaticFunction<int(int)> func;
func = &function;
REQUIRE(!!func);
REQUIRE(func(1) == 6);
}
SECTION("member function can be assigned")
{
StaticFunction<int(int)> func;
Functor functor{5};
// The bound pointer &functor must stay valid for as long as `func` is called.
func = std::bind(&Functor::operator(), &functor, std::placeholders::_1); // NOLINT(modernize-avoid-bind)
REQUIRE(!!func);
REQUIRE(func(1) == 6);
}
SECTION("functor can be assigned")
{
StaticFunction<int(int)> func;
func = Functor{5};
REQUIRE(!!func);
REQUIRE(func(1) == 6);
}
SECTION("lambda can be assigned")
{
StaticFunction<int(int)> func;
func = lambda;
REQUIRE(!!func);
REQUIRE(func(1) == 6);
}
// Selects the std::reference_wrapper overload of operator=.
// NOTE(review): only the call result is checked, so this does not distinguish
// between `func` referring to `functor` and `func` holding a copy of it.
SECTION("reference wrapped function can be assigned")
{
auto functor = Functor{5};
std::reference_wrapper const wrapped{functor};
StaticFunction<int(int)> func;
func = wrapped;
REQUIRE(!!func);
REQUIRE(func(1) == 6);
}
}
// Calling an empty StaticFunction (const or not) must end the process via the
// Expects() precondition check.
TEST_CASE("invoke function", "[util][StaticFunction]")
{
    //! Installs a terminate handler for the lifetime of the scope and restores the
    //! previous one afterwards, so this test case does not change the process-wide
    //! handler for test cases that run after it (even if a REQUIRE throws).
    class TerminateHandlerScope
    {
    public:
        explicit TerminateHandlerScope(std::terminate_handler handler) : previous_{std::set_terminate(handler)} {}
        ~TerminateHandlerScope() { std::set_terminate(previous_); }
        TerminateHandlerScope(TerminateHandlerScope const &) = delete;
        TerminateHandlerScope(TerminateHandlerScope &&) = delete;
        TerminateHandlerScope &operator=(TerminateHandlerScope const &) = delete;
        TerminateHandlerScope &operator=(TerminateHandlerScope &&) = delete;

    private:
        std::terminate_handler previous_;
    };

    // NOTE(review): presumably Expects() reaches std::terminate on a violation
    // (GSL's default) — confirm against the configured GSL. A lambda is used rather
    // than `std::abort` itself because taking the address of a standard library
    // function that is not designated addressable is unspecified since C++20.
    TerminateHandlerScope const abort_on_terminate{[] { std::abort(); }};

    SECTION("invoking empty function violates precondition")
    {
        REQUIRE_DEATH([] {
            StaticFunction<void()> func{};
            func();
        });
        REQUIRE_DEATH([] {
            StaticFunction<void()> const func{};
            func();
        });
    }
}
@glennvl
Copy link
Author

glennvl commented Dec 14, 2022

Tested on compiler explorer with

  • options: -std=c++20 -O3 -Wall -Wextra -Wpedantic -Wshadow -Wconversion -Wsign-conversion -Wdouble-promotion -Werror -lCatch2Main -fsanitize=address,undefined
  • libraries: Catch2 3.0.0-preview3
  • gcc 12.2
  • clang 15.0.0
| Version | Compiler Explorer | Notes |
| --- | --- | --- |
| 3 | — | `std::aligned_storage` is deprecated in C++23; use a `std::array` of `std::byte` instead |
| 2 | https://compiler-explorer.com/z/faraEvz95 | use `std::construct_at` and make `operator==` a hidden friend |
| 1 | https://compiler-explorer.com/z/MTP3x3enq | initial version |

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment