Skip to content

Instantly share code, notes, and snippets.

@mtao
Last active October 8, 2023 23:50
Show Gist options
  • Save mtao/567130fa412eb5ef13fdc599d16e6d8f to your computer and use it in GitHub Desktop.
Save mtao/567130fa412eb5ef13fdc599d16e6d8f to your computer and use it in GitHub Desktop.
Generic caching of multiple return types when switching between compile time and runtime polymorphism
#include <functional>
#include <map>
#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>

#include <spdlog/spdlog.h>
// Say you have a few heavy data structures derived from a single base class
// that has a value/enum that can be used to identify which derived class you have
// (e.g. a simplicial complex class with different derived classes for dimensions 2, 3, 4).
// Furthermore, say you have to call some functor on each one, but the functor returns
// a different type for each derived class
// (e.g. an operation like edge split/collapse).
// Finally, say you need to cache all of these values so someone can replay the results.
// This file contains a generic-ish means of doing this mess
// (e.g. a user wants to identify the "reference vertex" of the op in a particular SC to update an attribute on it).
// Polyfill for the C++20 std::unwrap_reference / std::unwrap_ref_decay traits,
// compiled only when the standard library does not advertise them through the
// __cpp_lib_unwrap_ref feature-test macro.
// Follows the "Possible implementation" from:
// https://en.cppreference.com/w/cpp/utility/functional/unwrap_reference
// NOTE(review): adding declarations to namespace std is formally undefined
// behavior; it is kept so the rest of the file can spell
// std::unwrap_ref_decay_t on every toolchain.
#if !defined(__cpp_lib_unwrap_ref)
namespace std {
// Primary template: a type that is not a reference_wrapper maps to itself.
template <class T>
struct unwrap_reference {
    using type = T;
};
// std::reference_wrapper<U> unwraps to the plain lvalue reference U&.
template <class U>
struct unwrap_reference<std::reference_wrapper<U>> {
    using type = U&;
};
// Alias that the standard facility also provides; without it, code written
// against std::unwrap_reference_t would not compile when the polyfill is used.
template <class T>
using unwrap_reference_t = typename unwrap_reference<T>::type;
// Decays T first (drops references and cv-qualifiers), then unwraps it.
template <class T>
struct unwrap_ref_decay : std::unwrap_reference<std::decay_t<T>> {};
template <class T>
using unwrap_ref_decay_t = typename unwrap_ref_decay<T>::type;
}  // namespace std
#endif
// Base type shared by every input. `type` is the tag that identifies which
// derived struct an object really is (as_variant switches on it); `id` is a
// per-object identifier.
struct Input {
    int type = -1;  // -1 is not one of the tags handled by as_variant
    int id = 0;     // defaulted so a default-constructed Input never holds an
                    // indeterminate value
};
// Example derived inputs; each one stamps its own tag into Input::type.
// A carries tag 0.
struct A : Input {
    A(int id) : Input{/*type=*/0, /*id=*/id} {}
};
// Example derived input carrying tag 1.
struct B : Input {
    B(int id) : Input{/*type=*/1, /*id=*/id} {}
};
// Example derived input carrying tag 2.
struct C : Input {
    C(int id) : Input{/*type=*/2, /*id=*/id} {}
};
// My target application's "Input" class is quite heavy and the Input objects
// persist for long periods of time relative to what this is being used for, so
// I want to use a variant of references rather than values
//
// Helper alias: a std::variant whose alternatives are
// std::reference_wrapper<T> for each listed T, i.e. a variant of references.
template <typename... Referents>
using ReferenceWrapperVariant =
    std::variant<std::reference_wrapper<Referents>...>;
// Variant that can hold a reference to any one of the concrete Input types
// (A, B or C); it stores a std::reference_wrapper, not a copy of the object.
using InputVariant = ReferenceWrapperVariant<A, B, C>;
// Wraps `value` in the InputVariant alternative selected by its `type` tag:
// 0 -> A, 1 -> B, 2 -> C. The variant only holds a reference, so `value` must
// outlive the returned variant.
// The static_casts trust the tag: `value` has to really be an object of the
// derived type that its tag names.
// Throws the string literal "InvalidInput" (a const char*) for any other tag.
InputVariant as_variant(Input& value) {
    switch (value.type) {
        case 0:
            return std::ref(static_cast<A&>(value));
        case 1:
            return std::ref(static_cast<B&>(value));
        case 2:
            return std::ref(static_cast<C&>(value));
        default:
            // The thrown type is kept as-is so existing catch sites still match.
            throw "InvalidInput";
    }
    // Every path above returns or throws, so no fallback return is needed.
}
// Computes, for a functor and a variant of (reference-wrapped) input types,
// the variant of per-input return types.
// The first functor argument is the alternative being selected from the
// variant (unwrapped to an lvalue reference); every remaining argument is
// passed as a const reference.
// The primary template is intentionally empty; only the specialization for
// std::variant below provides members.
template <typename Functor, typename... Args>
struct ReturnVariantHelper {};

template <typename Functor, typename... Alternatives, typename... ExtraArgs>
struct ReturnVariantHelper<Functor, std::variant<Alternatives...>,
                           ExtraArgs...> {
    // Decayed result of calling Functor with one specific alternative.
    template <typename Alternative>
    using ReturnType = std::decay_t<
        std::invoke_result_t<Functor, std::unwrap_ref_decay_t<Alternative>&,
                             const ExtraArgs&...>>;

    // One alternative per input alternative, in the same order.
    // NOTE(review): if two inputs yield the same decayed return type the
    // variant gets duplicate alternatives, and type-based access to it
    // (std::get<T>, std::in_place_type<T>) would no longer compile.
    using type = std::variant<ReturnType<Alternatives>...>;
};
// Interface for reading off the return values from data
template <typename Functor, typename... OtherArgumentTypes>
class ReturnDataStore {
public:
using TypeHelper =
ReturnVariantHelper<Functor, InputVariant, OtherArgumentTypes...>;
using ReturnVariant = typename TypeHelper::type;
// a pointer to an input and some other arguments
using KeyType = std::tuple<const Input*, OtherArgumentTypes...>;
auto get_id(const Input& input, const OtherArgumentTypes&... ts) const {
// other applications might use a fancier version of get_id
return KeyType(&input, ts...);
}
// Add new data by giving the InputType
// InputType is used to make sure the pair of Input/Output is valid and to
// extract an id
template <typename InputType, typename ReturnType>
void add(const InputType& input, ReturnType&& return_data,
const OtherArgumentTypes&... args) {
using ReturnType_t = std::decay_t<ReturnType>;
static_assert(!std::is_same_v<std::decay_t<InputType>, Input>,
"Don't pass in a input, use variant/visitor to get its "
"derived type");
// if the user passed in a input class lets try re-invoking with a
// derived type
auto id = get_id(input, args...);
using ExpectedReturnType =
typename TypeHelper::template ReturnType<InputType>;
static_assert(std::is_convertible_v<ReturnType_t, ExpectedReturnType>,
"Second argument should be the return value of a Functor "
"(or convertible at "
"least) ");
m_data.emplace(id,
ReturnVariant(std::in_place_type_t<ExpectedReturnType>{},
std::forward<ReturnType>(return_data)));
}
// let user get the variant for a specific Input derivate
const auto& get_variant(const Input& input,
const OtherArgumentTypes&... ts) const {
auto id = get_id(input, ts...);
return m_data.at(id);
}
// get the type specific input
template <typename InputType>
auto get(const InputType& input, const OtherArgumentTypes&... ts) const {
static_assert(!std::is_same_v<std::decay_t<InputType>, Input>,
"Don't pass in a input, use variant/visitor to get its "
"derived type");
using ExpectedReturnType =
typename TypeHelper::template ReturnType<InputType>;
return std::get<ExpectedReturnType>(get_variant(input, ts...));
}
private:
std::map<KeyType, ReturnVariant> m_data;
};
// Runs a functor on Inputs and records every result in `return_data`.
//
// Functor ownership: when Functor is a non-reference type the functor is
// moved into the Runner, so constructing from a temporary
// (`Runner r(TestFunctor{});`) is safe. When Functor is an lvalue-reference
// type (an lvalue was passed and deduced through the deduction guides) only a
// reference is kept and the caller must keep the functor alive.
template <typename Functor, typename... OtherTypes>
class Runner {
   public:
    // Takes the functor; std::forward moves it for value types and binds the
    // reference for reference types.
    Runner(Functor&& f) : func(std::forward<Functor>(f)) {}
    // Same as above; the tuple argument is unused and only carries the extra
    // argument types for class template argument deduction.
    Runner(Functor&& f, std::tuple<OtherTypes...>)
        : func(std::forward<Functor>(f)) {}

    // Dispatches `input` to its derived type, calls the functor with it and
    // `ts...`, and stores the result under (input, ts...).
    // Propagates whatever as_variant throws for an unknown type tag.
    void run(Input& input, const OtherTypes&... ts) {
        auto var = as_variant(input);
        std::visit(
            [&](auto& t) {
                auto& v = t.get();  // unwrap the reference_wrapper
                return_data.add(v, func(v, ts...), ts...);
            },
            var);
    }

    ReturnDataStore<Functor, OtherTypes...> return_data;

   private:
    // Stored by value (const, so the const call operator is still the one
    // used). For a reference Functor this collapses to that reference type.
    const Functor func;
};
// Deduction guide: ReturnDataStore(f) names ReturnDataStore<decay_t<Functor>>.
// NOTE(review): no ReturnDataStore constructor taking a functor is visible in
// this file, so this guide looks unusable as written — confirm.
template <typename Functor>
ReturnDataStore(Functor&& f) -> ReturnDataStore<std::decay_t<Functor>>;
// Deduction guide: Runner(f, std::tuple<Ts...>{}) deduces the extra argument
// types from the tuple's element types (decayed). Functor is not decayed, so
// an lvalue functor deduces an lvalue-reference Functor.
template <typename Functor, typename... Ts>
Runner(Functor&& f, std::tuple<Ts...>) -> Runner<Functor, std::decay_t<Ts>...>;
// Deduction guide: Runner(f) with no extra arguments.
template <typename Functor>
Runner(Functor&& f) -> Runner<Functor>;
struct TestFunctor {
template <typename T>
auto operator()(T& input) const {
using TT = std::unwrap_ref_decay_t<T>;
return std::tuple<TT, int>(input, input.id);
};
};
struct TestFunctor2Args {
template <typename T>
auto operator()(T& input, int data) const {
using TT = std::unwrap_ref_decay_t<T>;
return std::tuple<TT, int>(input, input.id * data);
};
};
// Demo driver: calls a functor directly, then through a Runner with and
// without an extra argument, logging each cached result.
int main(int argc, char* argv[]) {
    // Logs one (input copy, value) result as "type,id = value".
    const auto report = [](const auto& result) {
        const auto& [ap, i] = result;
        spdlog::info("{},{} = {}", ap.type, ap.id, i);
    };

    A a(0);
    B b(2);
    C c(4);

    // Call the functor once directly, without any caching.
    report(TestFunctor{}(a));

    // Runner around the single-argument functor.
    Runner r(TestFunctor{});
    r.run(a);
    r.run(b);
    r.run(c);
    report(r.return_data.get(a));
    report(r.return_data.get(b));
    report(r.return_data.get(c));

    // Runner around the two-argument functor; the tuple only names the type of
    // the extra argument.
    Runner r2(TestFunctor2Args{}, std::tuple<int>{});
    r2.run(a, 3);
    r2.run(b, 5);
    r2.run(c, 7);
    report(r2.return_data.get(a, 3));
    report(r2.return_data.get(b, 5));
    report(r2.return_data.get(c, 7));

    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment