Create a gist now

Instantly share code, notes, and snippets.

#include <cassert>
#include <cstdint>
#include <type_traits>
/// Base case of index_of: the candidate list is exhausted without a match.
///
/// Contributes 0 so that, for a type that is not in the list, the recursive
/// overload below ends up returning the number of candidates (a one-past-the-end
/// index). Previously this returned 1, which made the "not found" result
/// one larger than the list length.
template<class T>
constexpr std::intptr_t index_of() {
    return 0;
}

/// Returns the zero-based position of the first occurrence of T in the list
/// T0, Ts...; returns the length of that list if T does not occur in it.
template<class T, class T0, class... Ts>
constexpr std::intptr_t index_of() {
    if constexpr (std::is_same<T, T0>::value) {
        return 0;
    } else {
        // Not the head: the answer is one past whatever the tail reports.
        return 1 + index_of<T, Ts...>();
    }
}
template<class T, class... Ts>
constexpr std::intptr_t bit_of() {
auto constexpr value = std::integral_constant<std::intptr_t, 1>() << index_of<T, Ts...>();
return value == sizeof...(Ts) ? -1 : value;
}
/// A pointer-sized value that holds a pointer of exactly one of the pointer
/// types listed in Ts..., remembering which one by tagging the low bits.
///
/// The address is stored as a std::intptr_t with the bit returned by
/// bit_of<P, Ts...>() OR-ed in. This only round-trips when the stored address
/// has its low sizeof...(Ts) bits clear, i.e. the pointee is aligned to at
/// least (1 << sizeof...(Ts)) bytes. That is a property of the address rather
/// than of the type, so it is verified with assert() whenever a pointer is
/// stored.
template<class... Ts>
class UnionPtr {
    static_assert(sizeof...(Ts) > 1, "Requires at least 2 types");
    static_assert((1u << sizeof...(Ts)) <= alignof(void*), "Maximum supported types exceeded");
    static_assert((std::is_pointer<Ts>::value && ...), "Only pointer types");

    /// Mask covering every low bit that may carry a type tag (one per type in Ts...).
    static constexpr std::intptr_t kTagMask = (std::intptr_t{1} << sizeof...(Ts)) - 1;

    /// The held address with the tag bit of its type OR-ed in.
    std::intptr_t ptr_;

    /// Converts ptr to its tagged integer form.
    /// T* must be one of Ts... (checked at compile time) and ptr must have its
    /// tag bits clear (checked with assert).
    template<class T>
    static std::intptr_t encode(T* ptr) {
        // NOTE: alignof(T*) is the alignment of the pointer object itself, not of
        // the pointee, so this does not prove the address has free low bits.
        // It is kept unchanged for compatibility; the assert below does the real check.
        static_assert((1u << sizeof...(Ts)) <= alignof(T*), "Maximum supported types exceeded");
        auto constexpr flag = bit_of<T*, Ts...>();
        static_assert(flag != -1, "Unregistered type");
        std::intptr_t ptr_as_int = reinterpret_cast<std::intptr_t>(ptr);
        // An address that already uses a tag bit would be reported under the wrong
        // type or decoded to a different address; fail loudly instead.
        assert((ptr_as_int & kTagMask) == 0 &&
               "UnionPtr: pointer is not aligned enough to hold the type tag");
        return ptr_as_int | flag;
    }

    /// Recovers a P from the tagged integer, or nullptr when the tag bit of P
    /// is not set (i.e. a different type is currently stored).
    template<class P>
    static P decode(std::intptr_t ptr_as_int) {
        static_assert(std::is_pointer<P>::value, "decoding a pointer");
        static_assert((1u << sizeof...(Ts)) <= alignof(P), "Maximum supported types exceeded");
        auto constexpr flag = bit_of<P, Ts...>();
        static_assert(flag != -1, "Unregistered type");
        // The tag bit is known to be set in the first branch, so XOR clears it
        // and yields the original address.
        return flag & ptr_as_int ?
            reinterpret_cast<P>(ptr_as_int ^ flag) :
            nullptr;
    }

public:
    /// Stores ptr, tagged as T*. T* must be one of Ts...
    template<class T>
    UnionPtr(T* ptr) : ptr_(encode(ptr)) {}

    /// Replaces the held pointer with ptr, tagged as T*. T* must be one of Ts...
    template<class T>
    void set(T* ptr) {
        ptr_ = encode(ptr);
    }

    /// Returns the held pointer if it was stored as a P, otherwise nullptr.
    /// P must be one of Ts...
    template<class P>
    P get() const {
        return decode<P>(ptr_);
    }
};
#include <iostream>
int main() {
// force alignment because we get no garentee for stack allocation
alignas(8) int a = 1;
alignas(8) double b = 2.;
alignas(8) bool c = true;
UnionPtr<int*, double*, bool*> U(&a);
std::cout << U.get<int*>() << ": " << *U.get<int*>() << "\n";
std::cout << U.get<double*>() << "\n";
std::cout << U.get<bool*>() << "\n";
U.set(&b);
std::cout << U.get<int*>() << "\n";
std::cout << U.get<double*>() << ": " << *U.get<double*>() << "\n";
std::cout << U.get<bool*>() << "\n";
U.set(&c);
std::cout << U.get<int*>() << "\n";
std::cout << U.get<double*>() << "\n";
std::cout << U.get<bool*>() << ": " << *U.get<bool*>() << "\n";
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment