-
-
Save smunix/15c4b6c5a4bb7e917e97b3085e5a2bc7 to your computer and use it in GitHub Desktop.
{-# LANGUAGE GADTs #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
-- | | |
module Arithmetic.Expr where | |
import Data.Char (chr) | |
import GHC.Real (RealFrac (truncate)) | |
{- | |
Expr :: Term | |
Term :: Int | |
| Factor + Factor | |
Factor :: Term | |
| Factor x Factor | |
-} | |
{- | |
data Term | |
= MkPrim Int | |
| MkAdd Factor Factor | |
deriving (Show, Eq) | |
-} | |
-- | Monomorphic term of the toy grammar: either a primitive 'Int'
-- or the sum of two factors.
data Term where
  -- | A literal integer.
  MkPrim :: Int -> Term
  -- | Addition of two factors.
  MkAdd :: Factor -> Factor -> Term
  deriving (Show, Eq)
-- | Term generalised over the type @a@ of its primitive values.
data TermG a where
  -- | A primitive value of type @a@.
  MkPrimG :: a -> TermG a
  -- | Addition of two generic factors.
  MkAddG :: FactorG a -> FactorG a -> TermG a
  deriving (Show, Eq)
-- | Maps a function over every primitive stored in the term,
-- keeping the shape of the tree unchanged.
instance Functor TermG where
  fmap f term = case term of
    MkPrimG a -> MkPrimG (f a)
    MkAddG l r -> MkAddG (fmap f l) (fmap f r)
{- d = MkPrimG 1.0 | |
i = MkPrimG 2 | |
s = MkPrimG "Mo" | |
-} | |
-- | Returns the first element of a list, or 'Nothing' for the empty list.
fnL :: forall a. [] a -> Maybe a
fnL xs = case xs of
  [] -> Nothing
  (x : _) -> Just x
-- | Evaluates a generic term: a primitive is its own value, an addition
-- is the sum of both evaluated factors.
fnG :: forall a. (Num a) => TermG a -> a
fnG term = case term of
  MkPrimG a -> a
  MkAddG l r -> evalF l + evalF r
-- | Evaluates a generic factor: a wrapped term is evaluated with 'fnG',
-- a multiplication is the product of both evaluated factors.
evalF :: forall a. Num a => FactorG a -> a
evalF factor = case factor of
  MkTermG t -> fnG t
  MkMulG l r -> evalF l * evalF r
{- | |
data Factor | |
= MkTerm Term | |
| MkMul Factor Factor | |
deriving (Show, Eq) | |
-} | |
-- | Monomorphic factor of the toy grammar: either a wrapped 'Term'
-- or the product of two factors.
data Factor where
  -- | Lifts a term into a factor.
  MkTerm :: Term -> Factor
  -- | Multiplication of two factors.
  MkMul :: Factor -> Factor -> Factor
  deriving (Show, Eq)
-- | Factor generalised over the type @a@ of the primitives it contains.
data FactorG a where
  -- | Lifts a generic term into a factor.
  MkTermG :: TermG a -> FactorG a
  -- | Multiplication of two generic factors.
  MkMulG :: FactorG a -> FactorG a -> FactorG a
  deriving (Show, Eq)
-- | Maps a function over every primitive reachable from the factor,
-- keeping the shape of the tree unchanged.
instance Functor FactorG where
  fmap f factor = case factor of
    MkTermG t -> MkTermG (fmap f t)
    MkMulG l r -> MkMulG (fmap f l) (fmap f r)
-- | An expression is simply a top-level 'Term'.
type Expr = Term

-- | A generic expression is simply a top-level 'TermG'.
type ExprG a = TermG a
{- e = 1 + 2 x 3 -} | |
{- | |
+ | |
1 x | |
2 3 | |
-} | |
-- | The sample expression @1 + 2 x 3@ as a monomorphic tree.
e :: Expr
e = MkAdd lhs rhs
  where
    -- Left operand of the addition: the literal 1.
    lhs :: Factor
    lhs = MkTerm (MkPrim 1)
    -- Right operand of the addition: the product 2 x 3.
    rhs :: Factor
    rhs = MkMul (MkTerm (MkPrim 2)) (MkTerm (MkPrim 3))
-- $> import Arithmetic.Expr | |
-- $> e | |
-- | The sample expression @1 + 2 x 3@ as a generic tree whose primitives
-- are fixed to 'Float' through the equality constraint.
eg :: forall a. (a ~ Float) => ExprG a
eg = MkAddG lhs rhs
  where
    -- Left operand of the addition: the literal 1.
    lhs :: FactorG a
    lhs = MkTermG (MkPrimG 1)
    -- Right operand of the addition: the product 2 x 3.
    rhs :: FactorG a
    rhs = MkMulG (MkTermG (MkPrimG 2)) (MkTermG (MkPrimG 3))
-- $> import Arithmetic.Expr | |
-- $> (fnG eg, eg) | |
-- $> import Data.Char | |
-- $> (fmap (chr) . (fmap truncate) $ eg, eg) | |
-- $> (fmap (chr . truncate) $ eg, eg) |
{- | |
BNF grammar representation (Backus–Naur Form) | |
Expr : Expr + Term | |
| Expr - Term | |
| Term | |
Term : Term * Factor | |
| Term / Factor | |
| Factor | |
Factor : Int | |
-} | |
-- | Grammar rule @Factor : Int@ — a factor is an integer literal.
data Factor where
  -- | Wraps an integer literal.
  Int :: Int -> Factor
  deriving (Show)
-- | Grammar rule @Term : Term * Factor | Term / Factor | Factor@.
data Term where
  -- | Multiplication of a term by a factor.
  (:*:) :: Term -> Factor -> Term
  -- | Division of a term by a factor.
  (:/:) :: Term -> Factor -> Term
  -- | Lifts a single factor into a term.
  Factor :: Factor -> Term
  deriving (Show)
-- | Grammar rule @Expr : Expr + Term | Expr - Term | Term@.
data Expr where
  -- | Addition of a term to an expression.
  (:+:) :: Expr -> Term -> Expr
  -- | Subtraction of a term from an expression.
  (:-:) :: Expr -> Term -> Expr
  -- | Lifts a single term into an expression.
  Term :: Term -> Expr
  deriving (Show)
-- 1 + 2 x 3 | |
-- | The sample expression @1 + 2 x 3@ built from the BNF-shaped types.
v :: Expr
v = one :+: twoTimesThree
  where
    one = Term (Factor (Int 1))
    twoTimesThree = Factor (Int 2) :*: Int 3
-- $> v |
{-# LANGUAGE GADTs #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
-- | | |
module Regex.Expr where | |
{- | |
'a' | |
'ab' <=> ReConcat (ReChar 'a') (ReChar 'b') <=> 'a' * 'b' | |
'abc' <=> ReConcat (ReChar 'a') (ReConcat (ReChar 'b') (ReChar 'c')) <=> 'a' * 'b' * 'c' | |
'a|b' <=> 'a' + 'b' | |
'a*' | |
Kleene Star | |
\epsilon | |
-} | |
-- | Regular expression over symbols of type @a@.
data RegExpr a where
  -- | Matches a single symbol, e.g. @'a'@.
  ReChar :: a -> RegExpr a
  -- | The epsilon expression.
  ReEpsilon :: RegExpr a
  -- | Concatenation, e.g. @'ab'@.
  ReConcat :: RegExpr a -> RegExpr a -> RegExpr a
  -- | Alternation, e.g. @'a|b'@.
  ReChoice :: RegExpr a -> RegExpr a -> RegExpr a
  -- | Kleene star, e.g. @'a*'@.
  ReStar :: RegExpr a -> RegExpr a
-- | Expression tree with three node kinds, labelled by values of type @a@.
-- NOTE(review): presumably a lambda-calculus term (variable, abstraction,
-- application) — confirm with the author.
data E a where
  -- | A leaf carrying a single @a@.
  Var :: a -> E a
  -- | An @a@ paired with a sub-expression.
  Fun :: a -> E a -> E a
  -- | A pair of sub-expressions.
  App :: E a -> E a -> E a
-- | Type-level tree mirroring the shape of 'E', labelled by values of type @a@.
-- NOTE(review): presumably primitive / templated / applied types — confirm
-- with the author.
data Ty a where
  -- | A leaf carrying a single @a@.
  PrimTy :: a -> Ty a
  -- | An @a@ paired with a nested type.
  TempTy :: a -> Ty a -> Ty a
  -- | A pair of nested types.
  AppTy :: Ty a -> Ty a -> Ty a
#### 1 + 2 * 3
(+ (Expr (Term (Factor 1.1))) (* (Term (Factor 2.2)) (Factor 3.3)))
8.36
-- to int type
(1.1^2)::int + (2.2^2)::int * (3.3^2)::int = 41
-- to string type
1.100000 + 2.200000 * 3.300000
#### 1 - 2 * 0
(- (Expr (Term (Factor 1))) (/ (Term (Factor 2)) (Factor 0)))
Error: DivByZero
#include <array>
#include <cmath>
#include <functional>
#include <iostream>
#include <memory>
#include <sstream>
#include <string_view>
#include <utility>
#include <variant>
using namespace std::literals;
/// Concatenates the stream representation of every argument into one string.
///
/// Each argument is written with `operator<<` to a fresh `std::ostringstream`
/// in the order given; calling it without arguments yields an empty string.
template <typename... Args> [[nodiscard]] auto format(Args &&...args) -> std::string {
  std::ostringstream buffer;
  (static_cast<void>(buffer << args), ...);
  return buffer.str();
}
/// Merges several callables into one object whose call operator is the
/// overload set of all of them (handy as a `std::visit` visitor).
template <class... Fs> struct overloaded : Fs... {
  using Fs::operator()...;
};
/// Deduction guide so `overloaded{f, g}` infers its template arguments.
template <class... Fs> overloaded(Fs...) -> overloaded<Fs...>;
/// Error alternative of the expression value type; carries only a category.
struct [[nodiscard]] Error {
  /// Error categories; the numeric value doubles as an index into the
  /// message table in typeToString().
  enum Type : std::int8_t { NoIdea, DivByZero };

  /// Creates an error of category `t` (defaults to NoIdea).
  explicit Error(Type t = NoIdea) : value_{t} {}

  /// Returns the human-readable message for this error's category.
  [[nodiscard]] auto value() const noexcept -> std::string_view { return typeToString(); }

private:
  /// Looks the message up in a table ordered exactly like `Type`.
  [[nodiscard]] auto typeToString() const -> std::string_view {
    static constexpr std::array<std::string_view, 2> kMessages{
        std::string_view{"Error: I have no idea what's going on here"},
        std::string_view{"Error: DivByZero"}};
    const auto index = static_cast<std::size_t>(value_);
    return kMessages[index];
  }

  const Type value_ = NoIdea;
};
/// Immutable wrapper around an `int`; the integer alternative of the
/// expression value type.
struct [[nodiscard]] Int {
  /// Wraps `i` (defaults to 0).
  explicit Int(int i = 0) : value_{i} {}

  /// Returns the wrapped integer.
  [[nodiscard]] auto value() const noexcept -> int { return value_; }

private:
  const int value_ = 0;
};
/// Immutable wrapper around a `double`; the floating-point alternative of
/// the expression value type.
struct [[nodiscard]] Double {
  /// Wraps `d` (defaults to 0).
  explicit Double(double d = 0) : value_{d} {}

  /// Returns the wrapped number.
  [[nodiscard]] auto value() const noexcept -> double { return value_; }

private:
  const double value_ = 0;
};
/// Immutable wrapper around a `std::string`; the textual alternative of the
/// expression value type.
struct [[nodiscard]] String {
  /// Takes ownership of `s` (defaults to the empty string).
  explicit String(std::string s = {}) : value_(std::move(s)) {}

  /// On lvalues: returns a reference to the stored text.
  [[nodiscard]] auto value() const &noexcept -> const std::string & { return value_; }

  /// On rvalues: returns a copy, so the result outlives the temporary.
  [[nodiscard]] auto value() const && -> std::string { return std::string(value_); }

private:
  const std::string value_;
};
using BasicType = std::variant<Error, Int, Double, String>;
auto operator<<(std::ostream &os, const BasicType &v) -> std::ostream & {
return std::visit([&os](const auto &a) -> std::ostream & { return os << a.value(); }, v);
}
/// Adds two expression values.
///
/// - If either operand is an Error, that error is returned (the left one
///   wins when both are errors).
/// - Int + Int and Double + Double add numerically.
/// - String + String yields the text "<lhs> + <rhs>".
/// - Any other combination (operands of different non-error types) yields
///   Error(NoIdea).
///
/// The previous fall-through `return l + r;` converted both operands back to
/// BasicType and re-entered this very function, recursing without end for
/// Error/Error and for mismatched operand types.
[[nodiscard]] auto operator+(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return std::visit(
      [](const auto &l, const auto &r) -> BasicType {
        using LT = std::decay_t<decltype(l)>;
        using RT = std::decay_t<decltype(r)>;
        if constexpr (std::is_same_v<LT, Error>) {
          return l;
        } else if constexpr (std::is_same_v<RT, Error>) {
          return r;
        } else if constexpr (!std::is_same_v<LT, RT>) {
          return Error(Error::Type::NoIdea);
        } else if constexpr (std::is_same_v<LT, Int>) {
          return Int(l.value() + r.value());
        } else if constexpr (std::is_same_v<LT, Double>) {
          return Double(l.value() + r.value());
        } else {
          static_assert(std::is_same_v<LT, String>, "unhandled BasicType alternative");
          return String(format(l.value(), " + ", r.value()));
        }
      },
      lhs, rhs);
}
/// Subtracts two expression values.
///
/// - If either operand is an Error, that error is returned (the left one
///   wins when both are errors).
/// - Int - Int and Double - Double subtract numerically.
/// - String - String yields the text "<lhs> - <rhs>".
/// - Any other combination (operands of different non-error types) yields
///   Error(NoIdea).
///
/// The previous fall-through `return l - r;` converted both operands back to
/// BasicType and re-entered this very function, recursing without end for
/// Error/Error and for mismatched operand types.
[[nodiscard]] auto operator-(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return std::visit(
      [](const auto &l, const auto &r) -> BasicType {
        using LT = std::decay_t<decltype(l)>;
        using RT = std::decay_t<decltype(r)>;
        if constexpr (std::is_same_v<LT, Error>) {
          return l;
        } else if constexpr (std::is_same_v<RT, Error>) {
          return r;
        } else if constexpr (!std::is_same_v<LT, RT>) {
          return Error(Error::Type::NoIdea);
        } else if constexpr (std::is_same_v<LT, Int>) {
          return Int(l.value() - r.value());
        } else if constexpr (std::is_same_v<LT, Double>) {
          return Double(l.value() - r.value());
        } else {
          static_assert(std::is_same_v<LT, String>, "unhandled BasicType alternative");
          return String(format(l.value(), " - ", r.value()));
        }
      },
      lhs, rhs);
}
/// Multiplies two expression values.
///
/// - If either operand is an Error, that error is returned (the left one
///   wins when both are errors).
/// - Int * Int and Double * Double multiply numerically.
/// - String * String yields the text "<lhs> * <rhs>".
/// - Any other combination (operands of different non-error types) yields
///   Error(NoIdea).
///
/// The previous fall-through `return l * r;` converted both operands back to
/// BasicType and re-entered this very function, recursing without end for
/// Error/Error and for mismatched operand types.
[[nodiscard]] auto operator*(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return std::visit(
      [](const auto &l, const auto &r) -> BasicType {
        using LT = std::decay_t<decltype(l)>;
        using RT = std::decay_t<decltype(r)>;
        if constexpr (std::is_same_v<LT, Error>) {
          return l;
        } else if constexpr (std::is_same_v<RT, Error>) {
          return r;
        } else if constexpr (!std::is_same_v<LT, RT>) {
          return Error(Error::Type::NoIdea);
        } else if constexpr (std::is_same_v<LT, Int>) {
          return Int(l.value() * r.value());
        } else if constexpr (std::is_same_v<LT, Double>) {
          return Double(l.value() * r.value());
        } else {
          static_assert(std::is_same_v<LT, String>, "unhandled BasicType alternative");
          return String(format(l.value(), " * ", r.value()));
        }
      },
      lhs, rhs);
}
/// Divides two expression values.
///
/// - If either operand is an Error, that error is returned (the left one
///   wins when both are errors).
/// - Int / Int performs integer division; a zero divisor yields
///   Error(DivByZero) instead of dividing.
/// - Double / Double divides numerically; the divisor is not checked for zero.
/// - String / String yields the text "<lhs> / <rhs>".
/// - Any other combination (operands of different non-error types) yields
///   Error(NoIdea).
///
/// The previous fall-through was a copy-pasted `return l * r;`, which both
/// used the wrong operator and recursed without end through the BasicType
/// overloads for Error/Error and for mismatched operand types.
[[nodiscard]] auto operator/(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return std::visit(
      [](const auto &l, const auto &r) -> BasicType {
        using LT = std::decay_t<decltype(l)>;
        using RT = std::decay_t<decltype(r)>;
        if constexpr (std::is_same_v<LT, Error>) {
          return l;
        } else if constexpr (std::is_same_v<RT, Error>) {
          return r;
        } else if constexpr (!std::is_same_v<LT, RT>) {
          return Error(Error::Type::NoIdea);
        } else if constexpr (std::is_same_v<LT, Int>) {
          if (r.value() == 0) {
            return Error(Error::Type::DivByZero);
          }
          return Int(l.value() / r.value());
        } else if constexpr (std::is_same_v<LT, Double>) {
          return Double(l.value() / r.value());
        } else {
          static_assert(std::is_same_v<LT, String>, "unhandled BasicType alternative");
          return String(format(l.value(), " / ", r.value()));
        }
      },
      lhs, rhs);
}
/// Owning pointer type used for the nodes of the expression tree.
template <typename T> using Ptr = std::unique_ptr<T>;

/// Heap-allocates a `T` constructed from `args` and returns its owner.
template <typename T, typename... Args> [[nodiscard]] auto MkPtr(Args &&...args) -> Ptr<T> {
  Ptr<T> owner = std::make_unique<T>(std::forward<Args>(args)...);
  return owner;
}
/// Abstract root of every node in the expression tree.
///
/// Nodes can be neither copy- nor move-constructed; they are evaluated with
/// result() and rendered with toString().
struct [[nodiscard]] Base {
  /// Callback that maps one value to another; an empty function means
  /// "no transformation".
  using TransformFn = std::function<BasicType(const BasicType &)>;

  Base() = default;
  virtual ~Base() = default;

  Base(const Base &) = delete;
  Base(Base &&) = delete;
  auto operator=(const Base &) & -> Base & = default;
  auto operator=(Base &&) &noexcept -> Base & = default;

  /// Evaluates the node, handing `tf` down to the implementation.
  virtual auto result(TransformFn &&tf = {}) const -> BasicType = 0;

  /// Renders the node as text.
  [[nodiscard]] virtual auto toString() const -> std::string = 0;
};
/// Leaf node holding a single literal value.
struct [[nodiscard]] Factor : Base {
  /// Stores `v` as the literal of this leaf.
  explicit Factor(BasicType v) : value_(std::move(v)) {}

  /// Returns the stored literal, passed through `fn` when `fn` is non-empty.
  auto result(TransformFn &&fn) const -> BasicType override {
    if (!fn) {
      return value_;
    }
    return fn(value_);
  }

  /// Renders the leaf as "(Factor <value>)".
  [[nodiscard]] auto toString() const -> std::string override {
    std::string text = format("(Factor ", value_, ")");
    return text;
  }

private:
  const BasicType value_;
};
/// CRTP mixin implementing evaluation and printing for binary-operator nodes.
///
/// The derived type `T` must expose (to this class, e.g. via `friend`) the
/// members `lhs_`, `rhs_`, `act(l, r)` and `opToString()`.
template <typename T> struct [[nodiscard]] BaseOp {
  /// Evaluates the left operand, then the right operand, then combines both
  /// with `T::act`.
  ///
  /// The operands used to be evaluated as two arguments of one call, whose
  /// relative order is unspecified; they are sequenced explicitly so that a
  /// stateful `fn` observes the leaves left to right.
  auto result(Base::TransformFn &&fn) const -> BasicType {
    const T *const self = ref();
    // `fn` is forwarded twice on purpose: the result() implementations in this
    // file only bind or invoke it and never move from it.
    const BasicType lhs = self->lhs_->result(std::forward<Base::TransformFn>(fn));
    const BasicType rhs = self->rhs_->result(std::forward<Base::TransformFn>(fn));
    return self->act(lhs, rhs);
  }

  /// Renders the node in prefix form: "(<op> <lhs> <rhs>)".
  [[nodiscard]] auto toString() const -> std::string {
    const T *const self = ref();
    return format('(', self->opToString(), ' ', self->lhs_->toString(), ' ',
                  self->rhs_->toString(), ')');
  }

private:
  /// Downcasts to the derived node type (CRTP).
  [[nodiscard]] auto ref() const -> const T * { return static_cast<const T *>(this); }
};
/// Marker base for term-level nodes (TermBinOp, TermFactor); adds nothing to Base.
struct [[nodiscard]] Term : Base {};
/// Term-level binary node: `<term> * <factor>` or `<term> / <factor>`.
struct [[nodiscard]] TermBinOp : Term, BaseOp<TermBinOp> {
  /// Supported operators; the numeric value indexes the symbol table in
  /// opToString().
  enum class Op : std::int8_t { Mul, Div };

  /// Takes ownership of both operands.
  TermBinOp(Op op, Ptr<const Term> lhs, Ptr<const Factor> rhs)
      : op_{op}, lhs_{std::move(lhs)}, rhs_{std::move(rhs)} {}

  /// Evaluates both operands and combines them; see BaseOp::result.
  auto result(TransformFn &&fn) const -> BasicType override {
    return BaseOp<TermBinOp>::result(std::move(fn));
  }

  /// Renders the node in prefix form; see BaseOp::toString.
  [[nodiscard]] auto toString() const -> std::string override {
    return BaseOp<TermBinOp>::toString();
  }

private:
  // BaseOp reaches into act(), opToString(), lhs_ and rhs_.
  friend BaseOp<TermBinOp>;

  /// Applies the stored operator to two already evaluated operands.
  [[nodiscard]] auto act(const BasicType &l, const BasicType &r) const -> BasicType {
    if (op_ == Op::Mul) {
      return l * r;
    }
    if (op_ == Op::Div) {
      return l / r;
    }
    return {}; // shouldn't reach here
  }

  /// Returns the operator symbol used when printing.
  [[nodiscard]] auto opToString() const -> std::string_view {
    static constexpr std::array symbols{"*"sv, "/"sv};
    const auto index = static_cast<std::size_t>(op_);
    return symbols[index];
  }

  const Op op_;
  const Ptr<const Term> lhs_;
  const Ptr<const Factor> rhs_;
};
struct [[nodiscard]] TermFactor : Term {
explicit TermFactor(Ptr<const Factor> v) : factor_(std::move(v)) {}
auto result(TransformFn &&fn) const -> BasicType override {
return factor_->result(std::forward<TransformFn>(fn));
}
[[nodiscard]] auto toString() const -> std::string override {
return format("(Term ", factor_->toString(), ')');
}
private:
const Ptr<const Factor> factor_;
};
/// Marker base for expression-level nodes (ExprBinOp, ExprTerm); adds nothing to Base.
struct [[nodiscard]] Expr : Base {};
/// Expression-level binary node: `<expr> + <term>` or `<expr> - <term>`.
struct [[nodiscard]] ExprBinOp : Expr, BaseOp<ExprBinOp> {
  /// Supported operators; the numeric value indexes the symbol table in
  /// opToString().
  enum class Op : std::int8_t { Add, Sub };

  /// Takes ownership of both operands.
  ExprBinOp(Op op, Ptr<const Expr> lhs, Ptr<const Term> rhs)
      : op_{op}, lhs_{std::move(lhs)}, rhs_{std::move(rhs)} {}

  /// Evaluates both operands and combines them; see BaseOp::result.
  auto result(TransformFn &&fn) const -> BasicType override {
    return BaseOp<ExprBinOp>::result(std::move(fn));
  }

  /// Renders the node in prefix form; see BaseOp::toString.
  [[nodiscard]] auto toString() const -> std::string override {
    return BaseOp<ExprBinOp>::toString();
  }

private:
  // BaseOp reaches into act(), opToString(), lhs_ and rhs_.
  friend BaseOp<ExprBinOp>;

  /// Applies the stored operator to two already evaluated operands.
  [[nodiscard]] auto act(const BasicType &l, const BasicType &r) const -> BasicType {
    if (op_ == Op::Add) {
      return l + r;
    }
    if (op_ == Op::Sub) {
      return l - r;
    }
    return {}; // shouldn't reach here
  }

  /// Returns the operator symbol used when printing.
  [[nodiscard]] auto opToString() const -> std::string_view {
    static constexpr std::array symbols{"+"sv, "-"sv};
    const auto index = static_cast<std::size_t>(op_);
    return symbols[index];
  }

  const Op op_;
  const Ptr<const Expr> lhs_;
  const Ptr<const Term> rhs_;
};
struct [[nodiscard]] ExprTerm : Expr {
explicit ExprTerm(Ptr<const Term> v) : factor_(std::move(v)) {}
auto result(TransformFn &&fn) const -> BasicType override {
return factor_->result(std::forward<TransformFn>(fn));
}
[[nodiscard]] auto toString() const -> std::string override {
return format("(Expr ", factor_->toString(), ')');
}
private:
const Ptr<const Term> factor_;
};
void happy() {
const auto createAST = []() -> Ptr<const Expr> {
Ptr<const Factor> one = MkPtr<const Factor>(Double(1.1));
Ptr<const Factor> two = MkPtr<const Factor>(Double(2.2));
Ptr<const Factor> thr = MkPtr<const Factor>(Double(3.3));
Ptr<const Term> mul = MkPtr<const TermBinOp>(
TermBinOp::Op::Mul, MkPtr<const TermFactor>(std::move(two)), std::move(thr));
Ptr<const Expr> add = MkPtr<const ExprBinOp>(
ExprBinOp::Op::Add, MkPtr<const ExprTerm>(MkPtr<const TermFactor>(std::move(one))),
std::move(mul));
return add;
};
const auto ast = createAST();
std::cout << "#### 1 + 2 * 3\n";
std::cout << ast->toString() << '\n';
std::cout << ast->result() << '\n';
std::cout << "\n\n-- to int type\n";
std::cout << "(1.1^2)::int + (2.2^2)::int * (3.3^2)::int = ";
std::cout << ast->result([](const BasicType &v) -> BasicType {
const auto d = std::get<Double>(v).value();
return Int(std::floor(d * d));
}) << '\n';
std::cout << "\n\n-- to string type\n";
std::cout << ast->result([](const BasicType &v) -> BasicType {
const auto d = std::get<Double>(v).value();
return String(std::to_string(d));
}) << '\n';
}
/// Demo of the error path: builds `1 - 2 / 0` from Int leaves, prints the
/// tree and then its value (the result of dividing by an Int zero).
///
/// The banner used to read "1 - 2 * 0" although the tree contains a division,
/// and the subtraction node was stored in a local named `add`.
void divByZero() {
  const auto createAST = []() -> Ptr<const Expr> {
    Ptr<const Factor> one = MkPtr<const Factor>(Int(1));
    Ptr<const Factor> two = MkPtr<const Factor>(Int(2));
    Ptr<const Factor> zer = MkPtr<const Factor>(Int(0));
    // 2 / 0
    Ptr<const Term> div = MkPtr<const TermBinOp>(
        TermBinOp::Op::Div, MkPtr<const TermFactor>(std::move(two)), std::move(zer));
    // 1 - (2 / 0)
    Ptr<const Expr> sub = MkPtr<const ExprBinOp>(
        ExprBinOp::Op::Sub, MkPtr<const ExprTerm>(MkPtr<const TermFactor>(std::move(one))),
        std::move(div));
    return sub;
  };
  const auto ast = createAST();
  std::cout << "#### 1 - 2 / 0\n";
  std::cout << ast->toString() << '\n';
  std::cout << ast->result() << '\n';
}
/// Runs the error-free demo, prints three blank lines, then runs the
/// division-by-zero demo.
auto main() -> int {
  happy();
  std::cout << "\n\n\n";
  divByZero();
  return 0;
}
This is great @mo-xiaoming, thank you for sharing. Quick question: how would you handle the possibility of this call to fail?
const auto d = std::get<Double>(v).value();
Ideally, I'd want the error handling logic to not pollute much. I'm curious to know what your thoughts are on error handling.
Btw, I love your result
implementation. Isn't that a sort of fmap
simply? less generic though?
This is great @mo-xiaoming, thank you for sharing. Quick question: how would you handle the possibility of this call to fail?
const auto d = std::get<Double>(v).value();
Ideally, I'd want the error handling logic to not pollute much. I'm curious to know what your thoughts are on error handling.
I was thinking that when we create the AST, the underlying type is known to us. Using std::visit
for pattern matching all input types is doable, but I think if we got it wrong, then it is a logic error, should not be handled by program.
And I'm just copying what you've done in Haskell; otherwise I might simply throw exceptions here and at the point of the division by zero. Once those things happen, this is an ill-formed program and there is no way for us to recover from these errors, so we might as well fail immediately, at least for the trivial functionality we have now.
Error handling in C++ has been quite a hot topic in the community for several years, because nobody likes what we have today. The popular approaches are all flawed: in the current C++ language, exceptions must not be thrown in a hot path, and normal error returns or errno can be silently ignored (something like LLVM::Error
can solve this problem), std::optional
is conceptually equivalent to Maybe
, but the implementation in std makes it no different compared to nullptr
, and it is not composable
I'm not a big fan of exception, but look at what's happened in Go when return errors are not composable, nested if
s make code hard to read
Besides Herb Sutter's effort to eliminate the overhead of exceptions (Zero-overhead deterministic exceptions: Throwing values),
There is a hope,
There is a paper whose name might interest you: P0798R0: Monadic operations for std::optional, and there is an implementation on GitHub that makes optional
much like a real Maybe
Btw, I love your result implementation. Isn't that a sort of fmap simply? less generic though?
Yes, it is sort of fmap
, and yes, sadly it is much less generic.
Because template member functions cannot be virtual, that is, virtual template <typename Ret> Ret result()
is disallowed, then some generic type, std::any
, std::variant
or something similar has to be used as input/output to eliminate the usage of template for virtual functions.
The consequence is, if std::variant
is used, then there are going to be some large std::visit
or many nested if constexpr
, sigh. They're ugly
They are talking about adding pattern matching to C++, but there is no time table, I'm afraid...
Cool stuff!
Fyi, the BNF is a little too verbose although it serves well its purpose (of describing unambiguous grammars). The Expr
data type can be simplified as below:
data Exp
= Exp :+: Exp
| Exp :-: Exp
| Exp :*: Exp
| Exp :/: Exp
| Int i
Hopefully this simplifies your C++ implementation as well.
UPDATE I was wrong, I didn't consider if you were using Expr
for everything, then haskell implementation would've been less than 10 lines, so now my C++ code is 30x longer than yours
#### 1 + 2 * 3
(+ (Prim 1.1) (* (Prim 2.2) (Prim 3.3)))
8.36
-- to int type
(1.1^2)::int + (2.2^2)::int * (3.3^2)::int = 41
#### 1 - 2 / 0
(- (Prim 1) (/ (Prim 2) (Prim 0)))
Error: DivByZero: 2/0
Using Expr
to replace Term
and Factor
indeed makes code a little bit shorter, and removed String
type, because it doesn't make much sense at this moment
The current code is slightly less than 20x the length of your code, sigh, C++ is too verbose
#include <array>
#include <cmath>
#include <functional>
#include <iostream>
#include <memory>
#include <sstream>
#include <string_view>
#include <utility>
#include <variant>
using namespace std::literals;
/// Concatenates the stream representation of every argument into one string.
///
/// Each argument is written with `operator<<` to a fresh `std::ostringstream`
/// in the order given; calling it without arguments yields an empty string.
template <typename... Args> [[nodiscard]] auto format(Args &&...args) -> std::string {
  std::ostringstream buffer;
  (static_cast<void>(buffer << args), ...);
  return buffer.str();
}
/// Error alternative of the expression value type: a category plus an
/// optional free-form detail message.
struct [[nodiscard]] Error {
  /// Error categories; the numeric value doubles as an index into the
  /// description table in typeToString().
  enum Type : std::int8_t { NoIdea, DivByZero };

  /// Creates an error of category `t` without a detail message.
  explicit Error(Type t = NoIdea) : value_(t) {}

  /// Creates an error of category `t` with detail message `msg`.
  explicit Error(Type t, std::string msg) : value_(t), msg_(std::move(msg)) {}

  /// Returns "<category description>: <message>", or only the category
  /// description when no message was given (previously a dangling ": " was
  /// appended in that case).
  ///
  /// NOTE: the function is declared noexcept but builds a std::string, so an
  /// allocation failure terminates the program.
  [[nodiscard]] auto value() const noexcept -> std::string {
    std::string text(typeToString());
    if (!msg_.empty()) {
      text += ": ";
      text += msg_;
    }
    return text;
  }

private:
  /// Looks the description up in a table ordered exactly like `Type`.
  [[nodiscard]] auto typeToString() const -> std::string_view {
    static constexpr std::array<std::string_view, 2> kDescriptions{
        std::string_view{"Error: I have no idea what's going on here"},
        std::string_view{"Error: DivByZero"}};
    return kDescriptions[static_cast<std::size_t>(value_)];
  }

  const Type value_ = NoIdea;
  const std::string msg_;
};
/// Immutable wrapper around an `int`; the integer alternative of the
/// expression value type.
struct [[nodiscard]] Int {
  /// Wraps `i` (defaults to 0).
  explicit Int(int i = 0) : value_{i} {}

  /// Returns the wrapped integer.
  [[nodiscard]] auto value() const noexcept -> int { return value_; }

private:
  const int value_ = 0;
};
/// Immutable wrapper around a `double`; the floating-point alternative of
/// the expression value type.
struct [[nodiscard]] Double {
  /// Wraps `d` (defaults to 0).
  explicit Double(double d = 0) : value_{d} {}

  /// Returns the wrapped number.
  [[nodiscard]] auto value() const noexcept -> double { return value_; }

private:
  const double value_ = 0;
};
using BasicType = std::variant<Error, Int, Double>;
auto operator<<(std::ostream &os, const BasicType &v) -> std::ostream & {
return std::visit([&os](const auto &a) -> std::ostream & { return os << a.value(); }, v);
}
// cannot make them hidden friends, because BasicType is unknown inside these classes
/// Int + Int: numeric sum wrapped in an Int.
[[nodiscard]] auto operator+(Int lhs, Int rhs) -> BasicType {
  const int sum = lhs.value() + rhs.value();
  return Int(sum);
}

/// Int - Int: numeric difference wrapped in an Int.
[[nodiscard]] auto operator-(Int lhs, Int rhs) -> BasicType {
  const int difference = lhs.value() - rhs.value();
  return Int(difference);
}

/// Int * Int: numeric product wrapped in an Int.
[[nodiscard]] auto operator*(Int lhs, Int rhs) -> BasicType {
  const int product = lhs.value() * rhs.value();
  return Int(product);
}

/// Int / Int: integer quotient, or a DivByZero error whose message is
/// "<lhs>/<rhs>" when the divisor is zero.
[[nodiscard]] auto operator/(Int lhs, Int rhs) -> BasicType {
  const int dividend = lhs.value();
  const int divisor = rhs.value();
  if (divisor != 0) {
    return Int(dividend / divisor);
  }
  return Error(Error::Type::DivByZero, format(dividend, '/', divisor));
}
/// Double + Double: numeric sum wrapped in a Double.
[[nodiscard]] auto operator+(Double lhs, Double rhs) -> BasicType {
  const double sum = lhs.value() + rhs.value();
  return Double(sum);
}

/// Double - Double: numeric difference wrapped in a Double.
[[nodiscard]] auto operator-(Double lhs, Double rhs) -> BasicType {
  const double difference = lhs.value() - rhs.value();
  return Double(difference);
}

/// Double * Double: numeric product wrapped in a Double.
[[nodiscard]] auto operator*(Double lhs, Double rhs) -> BasicType {
  const double product = lhs.value() * rhs.value();
  return Double(product);
}

/// Double / Double: numeric quotient wrapped in a Double; the divisor is not
/// checked for zero.
[[nodiscard]] auto operator/(Double lhs, Double rhs) -> BasicType {
  const double quotient = lhs.value() / rhs.value();
  return Double(quotient);
}
// Error OP Error: both operands already failed, so a fresh NoIdea error with
// the message "two errors?" is reported.
//
// These used to be declared as `template <typename T>` although `T` appears
// nowhere in the parameter list. Such a parameter cannot be deduced, so the
// overloads were never viable for `a + b` syntax and Error OP Error fell
// through to the BasicType overloads instead. They are plain functions now.

/// Error + Error.
[[nodiscard]] auto operator+(const Error & /*lhs*/, const Error & /*rhs*/) -> BasicType {
  return Error(Error::NoIdea, "two errors?");
}

/// Error - Error.
[[nodiscard]] auto operator-(const Error & /*lhs*/, const Error & /*rhs*/) -> BasicType {
  return Error(Error::NoIdea, "two errors?");
}

/// Error * Error.
[[nodiscard]] auto operator*(const Error & /*lhs*/, const Error & /*rhs*/) -> BasicType {
  return Error(Error::NoIdea, "two errors?");
}

/// Error / Error.
[[nodiscard]] auto operator/(const Error & /*lhs*/, const Error & /*rhs*/) -> BasicType {
  return Error(Error::NoIdea, "two errors?");
}
// Error OP <anything that is not an Error>: the left-hand error is returned
// unchanged and the right operand is ignored. The enable_if constraint removes
// these overloads when T is Error itself.

/// Error + T: returns the left-hand error.
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator+(Error lhs, const T & /*rhs*/) -> BasicType {
return lhs;
}
/// Error - T: returns the left-hand error.
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator-(Error lhs, const T & /*rhs*/) -> BasicType {
return lhs;
}
/// Error * T: returns the left-hand error.
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator*(Error lhs, const T & /*rhs*/) -> BasicType {
return lhs;
}
/// Error / T: returns the left-hand error.
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator/(Error lhs, const T & /*rhs*/) -> BasicType {
return lhs;
}
// <anything that is not an Error> OP Error: the right-hand error is returned
// unchanged and the left operand is ignored. The enable_if constraint removes
// these overloads when T is Error itself.

/// T + Error: returns the right-hand error.
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator+(const T & /*lhs*/, Error rhs) -> BasicType {
return rhs;
}
/// T - Error: returns the right-hand error.
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator-(const T & /*lhs*/, Error rhs) -> BasicType {
return rhs;
}
/// T * Error: returns the right-hand error.
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator*(const T & /*lhs*/, Error rhs) -> BasicType {
return rhs;
}
/// T / Error: returns the right-hand error.
template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, Error>>>
[[nodiscard]] auto operator/(const T & /*lhs*/, Error rhs) -> BasicType {
return rhs;
}
namespace detail {
/// Shared dispatcher for the BasicType operators below.
///
/// Visits both variants and decides per pair of active alternatives:
/// - Error with Error        -> Error(NoIdea, "two errors?")
/// - Error with a value      -> that error, unchanged
/// - two values of one type  -> `op(l, r)` (the typed operator)
/// - two values of different types -> Error(NoIdea, "operand type mismatch")
///
/// Previously every pair was forwarded to `l OP r`. For pairs without a typed
/// overload (Error/Error, Int/Double, Double/Int) overload resolution
/// converted both operands back to BasicType and re-entered the same
/// operator, recursing without end.
template <typename Op>
[[nodiscard]] auto applyBinary(const BasicType &lhs, const BasicType &rhs, Op op) -> BasicType {
  return std::visit(
      [&op](const auto &l, const auto &r) -> BasicType {
        using LT = std::decay_t<decltype(l)>;
        using RT = std::decay_t<decltype(r)>;
        constexpr bool lhsIsError = std::is_same_v<LT, Error>;
        constexpr bool rhsIsError = std::is_same_v<RT, Error>;
        if constexpr (lhsIsError && rhsIsError) {
          return Error(Error::NoIdea, "two errors?");
        } else if constexpr (lhsIsError) {
          return l;
        } else if constexpr (rhsIsError) {
          return r;
        } else if constexpr (std::is_same_v<LT, RT>) {
          return op(l, r);
        } else {
          return Error(Error::NoIdea, "operand type mismatch");
        }
      },
      lhs, rhs);
}
} // namespace detail

/// Adds two expression values; see detail::applyBinary for error handling.
[[nodiscard]] auto operator+(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return detail::applyBinary(
      lhs, rhs, [](const auto &l, const auto &r) -> BasicType { return l + r; });
}

/// Subtracts two expression values; see detail::applyBinary for error handling.
[[nodiscard]] auto operator-(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return detail::applyBinary(
      lhs, rhs, [](const auto &l, const auto &r) -> BasicType { return l - r; });
}

/// Multiplies two expression values; see detail::applyBinary for error handling.
[[nodiscard]] auto operator*(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return detail::applyBinary(
      lhs, rhs, [](const auto &l, const auto &r) -> BasicType { return l * r; });
}

/// Divides two expression values; see detail::applyBinary for error handling.
[[nodiscard]] auto operator/(const BasicType &lhs, const BasicType &rhs) -> BasicType {
  return detail::applyBinary(
      lhs, rhs, [](const auto &l, const auto &r) -> BasicType { return l / r; });
}
/// Owning pointer type used for the nodes of the expression tree.
template <typename T> using Ptr = std::unique_ptr<T>;

/// Heap-allocates a `T` constructed from `args` and returns its owner.
template <typename T, typename... Args> [[nodiscard]] auto MkPtr(Args &&...args) -> Ptr<T> {
  Ptr<T> owner = std::make_unique<T>(std::forward<Args>(args)...);
  return owner;
}
/// Abstract root of every node in the expression tree.
///
/// Nodes can be neither copy- nor move-constructed; they are evaluated with
/// result() and rendered with toString().
struct [[nodiscard]] Expr {
  /// Callback that maps one value to another; an empty function means
  /// "no transformation".
  using TransformFn = std::function<BasicType(const BasicType &)>;

  Expr() = default;
  virtual ~Expr() = default;

  Expr(const Expr &) = delete;
  Expr(Expr &&) = delete;
  auto operator=(const Expr &) & -> Expr & = default;
  auto operator=(Expr &&) &noexcept -> Expr & = default;

  /// Evaluates the node, handing `tf` down to the implementation.
  virtual auto result(TransformFn &&tf = {}) const -> BasicType = 0;

  /// Renders the node as text.
  [[nodiscard]] virtual auto toString() const -> std::string = 0;
};
/// Leaf node holding a single literal value.
struct [[nodiscard]] Prim : Expr {
  /// Stores `v` as the literal of this leaf.
  explicit Prim(BasicType v) : value_(std::move(v)) {}

  /// Returns the stored literal, passed through `fn` when `fn` is non-empty.
  auto result(TransformFn &&fn) const -> BasicType override {
    if (!fn) {
      return value_;
    }
    return fn(value_);
  }

  /// Renders the leaf as "(Prim <value>)".
  [[nodiscard]] auto toString() const -> std::string override {
    std::string text = format("(Prim ", value_, ")");
    return text;
  }

private:
  const BasicType value_;
};
/// Binary-operator node: applies `*`, `/`, `+` or `-` to two sub-expressions.
struct [[nodiscard]] ExprBinOp : Expr {
  /// Supported operators; the numeric value indexes the symbol table in
  /// opToString().
  enum class Op : std::int8_t { Mul, Div, Add, Sub };

  /// Takes ownership of both operands.
  ExprBinOp(Op op, Ptr<const Expr> lhs, Ptr<const Expr> rhs)
      : op_(op), lhs_(std::move(lhs)), rhs_(std::move(rhs)) {}

  /// Evaluates the left operand, then the right operand, then applies the
  /// operator.
  ///
  /// The operands used to be evaluated as two arguments of one call, whose
  /// relative order is unspecified; they are sequenced explicitly so that a
  /// stateful `fn` observes the leaves left to right.
  auto result(TransformFn &&fn) const -> BasicType override {
    // `fn` is forwarded twice on purpose: the result() implementations in this
    // file only bind or invoke it and never move from it.
    const BasicType lhs = lhs_->result(std::forward<TransformFn>(fn));
    const BasicType rhs = rhs_->result(std::forward<TransformFn>(fn));
    return act(lhs, rhs);
  }

  /// Renders the node in prefix form: "(<op> <lhs> <rhs>)".
  [[nodiscard]] auto toString() const -> std::string override {
    return format('(', opToString(), ' ', lhs_->toString(), ' ', rhs_->toString(), ')');
  }

private:
  /// Applies the stored operator to two already evaluated operands.
  [[nodiscard]] auto act(const BasicType &l, const BasicType &r) const -> BasicType {
    switch (op_) {
    case Op::Mul:
      return l * r;
    case Op::Div:
      return l / r;
    case Op::Add:
      return l + r;
    case Op::Sub:
      return l - r;
    }
    return {}; // shouldn't reach here
  }

  /// Returns the operator symbol used when printing; the table order matches `Op`.
  [[nodiscard]] auto opToString() const -> std::string_view {
    static constexpr std::array os{"*"sv, "/"sv, "+"sv, "-"sv};
    return os[static_cast<std::size_t>(op_)];
  }

  const Op op_;
  const Ptr<const Expr> lhs_;
  const Ptr<const Expr> rhs_;
};
void happy() {
const auto createAST = []() -> Ptr<const Expr> {
Ptr<const Expr> one = MkPtr<const Prim>(Double(1.1));
Ptr<const Expr> two = MkPtr<const Prim>(Double(2.2));
Ptr<const Expr> thr = MkPtr<const Prim>(Double(3.3));
Ptr<const Expr> mul =
MkPtr<const ExprBinOp>(ExprBinOp::Op::Mul, std::move(two), std::move(thr));
Ptr<const Expr> add =
MkPtr<const ExprBinOp>(ExprBinOp::Op::Add, std::move(one), std::move(mul));
return add;
};
const auto ast = createAST();
std::cout << "#### 1 + 2 * 3\n";
std::cout << ast->toString() << '\n';
std::cout << ast->result() << '\n';
std::cout << "\n\n-- to int type\n";
std::cout << "(1.1^2)::int + (2.2^2)::int * (3.3^2)::int = ";
std::cout << ast->result([](const BasicType &v) -> BasicType {
const auto d = std::get<Double>(v).value();
return Int(std::floor(d * d));
}) << '\n';
}
void divByZero() {
const auto createAST = []() -> Ptr<const Expr> {
Ptr<const Expr> one = MkPtr<const Prim>(Int(1));
Ptr<const Expr> two = MkPtr<const Prim>(Int(2));
Ptr<const Expr> zer = MkPtr<const Prim>(Int(0));
Ptr<const Expr> div =
MkPtr<const ExprBinOp>(ExprBinOp::Op::Div, std::move(two), std::move(zer));
Ptr<const Expr> sub =
MkPtr<const ExprBinOp>(ExprBinOp::Op::Sub, std::move(one), std::move(div));
return sub;
};
const auto ast = createAST();
std::cout << "#### 1 - 2 / 0\n";
std::cout << ast->toString() << '\n';
std::cout << ast->result() << '\n';
}
/// Runs the error-free demo, prints three blank lines, then runs the
/// division-by-zero demo.
auto main() -> int {
  happy();
  std::cout << "\n\n\n";
  divByZero();
  return 0;
}
I think the following code should have similar functionality, but it is much more verbose, for the following reasons
fmap
is a pain in the a**. Have to use some generic types, like std::any
or std::variant