Skip to content

Instantly share code, notes, and snippets.

@Bananattack
Last active February 27, 2016 22:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Bananattack/1e2d3bbbf80f9ab63779 to your computer and use it in GitHub Desktop.
Save Bananattack/1e2d3bbbf80f9ab63779 to your computer and use it in GitHub Desktop.
A C++14 single-header variant type. It allows defining type-safe tagged unions with pattern matching/visiting. It probably has some implementation mistakes. I might use this for wiz, my high-level assembly language project.
#ifndef WIZ_VARIANT_H
#define WIZ_VARIANT_H
#include <cassert>
#include <cstddef>
#include <new>
#include <type_traits>
#include <utility>
namespace wiz {

/// Compile-time maximum of a list of std::size_t values (0 for an empty list).
/// Used by variant to size and align its inline storage.
template<std::size_t... Ns>
struct max_value;

template<std::size_t N, std::size_t... Ns>
struct max_value<N, Ns...> {
    static constexpr std::size_t value = max_value<Ns...>::value > N
        ? max_value<Ns...>::value
        : N;
};

template<>
struct max_value<> {
    static constexpr std::size_t value = 0;
};

/// Zero-based index of the first occurrence of U in Ts..., or -1 if U is absent.
/// variant uses this index as its runtime tag.
template<typename U, typename... Ts>
struct type_tag;

template<typename U, typename T, typename... Ts>
struct type_tag<U, T, Ts...> {
    // U is not the head: search the tail and shift a found index by one.
    static constexpr int value = type_tag<U, Ts...>::value >= 0
        ? type_tag<U, Ts...>::value + 1
        : -1;
};

template<typename U, typename... Ts>
struct type_tag<U, U, Ts...> {
    static constexpr int value = 0;
};

template<typename U>
struct type_tag<U> {
    static constexpr int value = -1;
};

/// Runtime dispatch on a variant tag. tag_dispatcher_<N, T, Ts...> treats tag N
/// as type T and forwards every other tag to tag_dispatcher_<N + 1, Ts...>.
/// A tag matching no alternative (e.g. -1) reaches the empty terminal case.
template<int N, typename... Ts>
struct tag_dispatcher_;

template<int N, typename T, typename... Ts>
struct tag_dispatcher_<N, T, Ts...> {
    /// Copy-constructs the T at src into the raw storage at dest when tag == N.
    static void copy(int tag, void* dest, const void* src) {
        if(tag == N) {
            ::new (dest) T(*static_cast<const T*>(src));
        } else {
            tag_dispatcher_<N + 1, Ts...>::copy(tag, dest, src);
        }
    }
    /// Move-constructs the T at src into the raw storage at dest when tag == N.
    static void move(int tag, void* dest, void* src) {
        if(tag == N) {
            ::new (dest) T(std::move(*static_cast<T*>(src)));
        } else {
            tag_dispatcher_<N + 1, Ts...>::move(tag, dest, src);
        }
    }
    /// Runs ~T() on the object at data when tag == N.
    static void destroy(int tag, void* data) {
        if(tag == N) {
            static_cast<T*>(data)->~T();
        } else {
            tag_dispatcher_<N + 1, Ts...>::destroy(tag, data);
        }
    }
    /// Invokes f with the const T at data when tag == N, converting the result to R.
    template<typename R, typename F>
    static R apply(int tag, F&& f, const void* data) {
        if(tag == N) {
            return std::forward<F>(f)(*static_cast<const T*>(data));
        } else {
            return tag_dispatcher_<N + 1, Ts...>::template apply<R, F>(tag, std::forward<F>(f), data);
        }
    }
};

template<int N>
struct tag_dispatcher_<N> {
    // No alternative matched the tag: there is nothing to copy, move or destroy.
    static void copy(int, void*, const void*) {}
    static void move(int, void*, void*) {}
    static void destroy(int, void*) {}
    // No alternative matched: the visitor is not called; return R() instead.
    template<typename R, typename F>
    static R apply(int, F&&, const void*) {
        return R();
    }
};

template<typename... Ts>
using tag_dispatcher = tag_dispatcher_<0, Ts...>;

/// Combines several callables into one overload set, used by variant::apply
/// to visit with one lambda per alternative.
/// NOTE: two callables of the identical type would make overload<F> a duplicate
/// base, which is ill-formed.
template<typename... Fs>
struct overload;

template <typename F>
struct overload<F> {
public:
    overload(F&& f) : f(std::forward<F>(f)) {}
    // The return type is deduced from a call on const F, matching the body.
    // That body runs inside a const member, so F is called as const.
    template<typename... Ts>
    auto operator()(Ts&&... args) const
        -> decltype(std::declval<const F&>()(std::forward<Ts>(args)...)) {
        return f(std::forward<Ts>(args)...);
    }
private:
    F f;
};

template <typename F, typename... Fs>
struct overload<F, Fs...> : overload<F>, overload<Fs...> {
    using overload<F>::operator();
    using overload<Fs...>::operator();
    overload(F&& f, Fs&&... fs) :
        overload<F>(std::forward<F>(f)),
        overload<Fs...>(std::forward<Fs>(fs)...) {}
};

/// A type-safe tagged union holding exactly one value of one of Ts...
///
/// which() reports the index of the active alternative within Ts...
/// If copying or moving the new value throws during assignment, the variant is
/// left valueless: which() returns -1, is<U>() is false for every U, and apply()
/// returns R() without calling any visitor.
template<typename... Ts>
class variant {
public:
    variant() = delete;

    /// Constructs a variant holding a copy of value; U must be one of Ts...
    template<typename U>
    variant(const U& value)
        : tag(type_tag<U, Ts...>::value) {
        static_assert_valid_type<U>();
        ::new (static_cast<void*>(&data)) U(value);
    }

    variant(const variant& other) : tag(other.tag) {
        tag_dispatcher<Ts...>::copy(tag, &data, &other.data);
    }

    variant(variant&& other) : tag(other.tag) {
        tag_dispatcher<Ts...>::move(tag, &data, &other.data);
    }

    ~variant() {
        tag_dispatcher<Ts...>::destroy(tag, &data);
    }

    variant& operator =(const variant& other) {
        // Self-assignment would otherwise destroy the value before copying it.
        if(this != &other) {
            tag_dispatcher<Ts...>::destroy(tag, &data);
            // Valueless until the copy succeeds. A throwing copy then cannot
            // lead the destructor to destroy the old value a second time.
            tag = -1;
            tag_dispatcher<Ts...>::copy(other.tag, &data, &other.data);
            tag = other.tag;
        }
        return *this;
    }

    variant& operator =(variant&& other) {
        // Self-move would otherwise move from an already-destroyed value.
        if(this != &other) {
            tag_dispatcher<Ts...>::destroy(tag, &data);
            // Valueless until the move succeeds (see copy assignment).
            tag = -1;
            tag_dispatcher<Ts...>::move(other.tag, &data, &other.data);
            tag = other.tag;
        }
        return *this;
    }

    /// Index of the active alternative within Ts..., or -1 if valueless.
    int which() const { return tag; }

    /// True if the active alternative is U (U must be one of Ts...).
    template<typename U>
    bool is() const {
        static_assert_valid_type<U>();
        return type_tag<U, Ts...>::value == tag;
    }

    /// Access the stored value as U.
    /// Precondition: is<U>(); checked by assert in debug builds.
    template<typename U>
    U& get() {
        static_assert_valid_type<U>();
        assert(is<U>() && "variant::get: U is not the active alternative");
        return *static_cast<U*>(static_cast<void*>(&data));
    }

    template<typename U>
    const U& get() const {
        static_assert_valid_type<U>();
        assert(is<U>() && "variant::get: U is not the active alternative");
        return *static_cast<const U*>(static_cast<const void*>(&data));
    }

    /// Calls f with the active value (as const) and returns its result converted to R.
    template<typename R, typename F>
    R apply(F&& f) const {
        return tag_dispatcher<Ts...>::template apply<R, F>(tag, std::forward<F>(f), &data);
    }

    /// Visits the active value with whichever of f, fs... is the best overload for it.
    template<typename R, typename F, typename... Fs>
    R apply(F&& f, Fs&&... fs) const {
        using overload_type = overload<F, Fs...>;
        overload_type visitor(std::forward<F>(f), std::forward<Fs>(fs)...);
        return tag_dispatcher<Ts...>::template apply<R, overload_type>(
            tag, std::move(visitor), &data);
    }

private:
    template<typename U>
    static void static_assert_valid_type() {
        static_assert(type_tag<U, Ts...>::value >= 0, "variant does not support the provided type.");
    }

    // Raw storage large and aligned enough for every alternative.
    using data_type = typename std::aligned_storage<
        max_value<sizeof(Ts)...>::value,
        max_value<alignof(Ts)...>::value>::type;

    int tag;
    data_type data;
};

}
#endif
// A small example program, using the variant for representing a simple expression tree.
#include <memory>
#include <iostream>
#include <wiz/variant.h>
// The alternatives are forward-declared so the variant alias can be named
// before they are defined; binary_operator_expression refers back to it.
struct binary_operator_expression;
struct number_expression;
using expression_variant = wiz::variant<binary_operator_expression, number_expression>;
// Interior node of the expression tree: combines the values of the `left`
// and `right` subexpressions with `operation`.
struct binary_operator_expression {
    // Arithmetic operators this node can apply.
    enum class operation_type {
        add,
        subtract,
        multiply,
        divide,
        modulo
    };

    binary_operator_expression(operation_type op,
                               const std::shared_ptr<expression_variant>& lhs,
                               const std::shared_ptr<expression_variant>& rhs)
        : operation(op)
        , left(lhs)
        , right(rhs) {}

    operation_type operation;
    // Children are held through pointers: expression_variant stores its
    // alternatives inline, so a direct member would be recursive.
    std::shared_ptr<expression_variant> left;
    std::shared_ptr<expression_variant> right;
};
// Leaf node of the expression tree: an unsigned integer literal.
struct number_expression {
    number_expression(std::size_t literal) : value{literal} {}

    std::size_t value;
};
void dump(std::ostream& out, expression_variant& expr) {
expr.apply<void>(
[&](const binary_operator_expression& expr) {
out << "(";
dump(out, *expr.left);
out << " ";
switch(expr.operation) {
case binary_operator_expression::operation_type::add: out << "+"; break;
case binary_operator_expression::operation_type::subtract: out << "-"; break;
case binary_operator_expression::operation_type::multiply: out << "*"; break;
case binary_operator_expression::operation_type::divide: out << "/"; break;
case binary_operator_expression::operation_type::modulo: out << "%"; break;
default: out << "???"; break;
}
out << " ";
dump(out, *expr.right);
out << ")";
},
[&](const number_expression& expr) {
out << expr.value;
});
}
int evaluate(expression_variant& expr) {
return expr.apply<int>(
[](const binary_operator_expression& expr) {
int left = evaluate(*expr.left);
int right = evaluate(*expr.right);
switch(expr.operation) {
case binary_operator_expression::operation_type::add: return left + right;
case binary_operator_expression::operation_type::subtract: return left - right;
case binary_operator_expression::operation_type::multiply: return left * right;
case binary_operator_expression::operation_type::divide: return left / right;
case binary_operator_expression::operation_type::modulo: return left % right;
default: return 0;
}
},
[](const number_expression& expr) {
return expr.value;
});
}
int main() {
auto expr = expression_variant(binary_operator_expression(
binary_operator_expression::operation_type::subtract,
std::make_shared<expression_variant>(binary_operator_expression(
binary_operator_expression::operation_type::add,
std::make_shared<expression_variant>(number_expression(2)),
std::make_shared<expression_variant>(
binary_operator_expression(
binary_operator_expression::operation_type::multiply,
std::make_shared<expression_variant>(number_expression(240)),
std::make_shared<expression_variant>(number_expression(90)))))),
std::make_shared<expression_variant>(number_expression(4))));
std::cout << "input: ";
dump(std::cout, expr);
std::cout << std::endl;
std::cout << "output: " << evaluate(expr) << std::endl;
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment