Last active
October 23, 2019 14:32
-
-
Save jbelloncastro/77d0d94e5cbc2f17d3c3007eaeb61444 to your computer and use it in GitHub Desktop.
Deduce a callback's parameter type and perform runtime type dispatch (via dynamic_cast) based on that type
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <type_traits> | |
#include <llvm/ADT/STLExtras.h> | |
/// Root of the polymorphic hierarchy. The virtual destructor makes the type
/// polymorphic, which is what allows dynamic_cast downcasts from `const Base*`.
struct Base {
  virtual ~Base() = default;
};
// Defined elsewhere. NOTE(review): presumably invokes `visit` on the nodes
// reachable from `base`; how the bool results are used is not shown in this file.
bool traverse(const Base *base, llvm::function_ref<bool(const Base *)> visit);
/// Calls `primary(args...)` when that call is well-formed, otherwise calls
/// `alternative(args...)`. The choice is made at compile time, so only the
/// selected call has to compile; the result of that call is returned by value.
struct dispatcher {
  template <class F, class U, class... Args>
  auto operator()(F &&primary, U &&alternative, Args &&...args) const {
    constexpr bool kPrimaryViable = std::is_invocable_v<F, Args...>;
    if constexpr (!kPrimaryViable) {
      return std::forward<U>(alternative)(std::forward<Args>(args)...);
    } else {
      return std::forward<F>(primary)(std::forward<Args>(args)...);
    }
  }
};
/// Concrete subclass of Base; traverse_known() downcasts to it and the
/// callback `f` declared below takes it.
struct Derived : public Base {};
// Dynamic dispatch for a known derived type | |
template <class F> | |
bool traverse_known( const Base* base, F &&f ) { | |
return traverse(base, [&] ( const Base *b ) -> bool { | |
// Fallback function if F does not implement callback for Derived | |
auto fallback = [] (auto /*derived*/) -> bool { return true; }; | |
if (auto d = dynamic_cast<const Derived*>(b)) { | |
// Calls f(d) if implemented, otherwise calls fallback(d) | |
return dispatcher()(std::forward<F>(f), fallback, d); | |
} else { | |
// b is not an instance of Derived, call fallback(b) | |
return fallback(b); | |
} | |
}); | |
} | |
namespace traits {
/// Maps a pointer-to-member-function type with exactly one parameter to that
/// parameter's type (exposed as `type`). Left undefined for anything else.
template <class T> struct call_arg_impl;

// Non-const call operator (e.g. a `mutable` lambda).
template <class R, class F, class Arg>
struct call_arg_impl<R (F::*)(Arg)> { using type = Arg; };
// In C++17 `noexcept` is part of the function type, so it needs its own match.
template <class R, class F, class Arg>
struct call_arg_impl<R (F::*)(Arg) noexcept> { using type = Arg; };
// Const call operator (e.g. an ordinary lambda).
template <class R, class F, class Arg>
struct call_arg_impl<R (F::*)(Arg) const> { using type = Arg; };
template <class R, class F, class Arg>
struct call_arg_impl<R (F::*)(Arg) const noexcept> { using type = Arg; };

/// Parameter type of a unary callable class with a single, non-template
/// operator() (generic lambdas and overloaded functors are not supported).
template <class F>
struct call_arg : call_arg_impl<decltype(&F::operator())> {};

// Function pointers (with and without noexcept).
template <class R, class Arg>
struct call_arg<R (*)(Arg)> { using type = Arg; };
template <class R, class Arg>
struct call_arg<R (*)(Arg) noexcept> { using type = Arg; };

/// Parameter type of the unary callable `T`. `T` is decayed first, so lvalue
/// references to callables (as deduced for forwarding references), cv-qualified
/// callables, and function references are accepted as well.
template <class T>
using call_arg_t = typename call_arg<std::decay_t<T>>::type;
} // namespace traits
template <class F> | |
bool traverse_deduce( const Base* base, F &&f, bool noMatchReturn = true ) { | |
using Arg = traits::call_arg_t<F>; | |
using Derived = std::remove_pointer_t<Arg>; | |
static_assert( std::is_base_of<Base,Derived>::value, "" ); | |
return traverse(base, [=,f=std::forward<F>(f)] ( const Base *b ) -> bool { | |
if (auto d = dynamic_cast<Arg>(b)) { | |
return f(d); | |
} else { | |
return noMatchReturn; | |
} | |
}); | |
} | |
bool foo (const Base *b) { | |
struct A : Base {}; | |
return traverse_deduce(b, []( const A* ) -> bool { return false; }); | |
} | |
// Callback over Derived nodes; defined elsewhere. Used by foo2() below.
bool f(const Derived *node);
bool foo2 (const Base *b) { | |
return traverse_deduce(b, &f); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment