Skip to content

Instantly share code, notes, and snippets.

@Sam-Belliveau
Last active October 10, 2021 20:02
Show Gist options
  • Save Sam-Belliveau/3c90f0f05368f0e5dbb0c9a0b37e1025 to your computer and use it in GitHub Desktop.
Save Sam-Belliveau/3c90f0f05368f0e5dbb0c9a0b37e1025 to your computer and use it in GitHub Desktop.
An Infix Notation Evaluator
/**
* MIT License
*
* Copyright (c) Sep 2019, Samuel Robert Belliveau
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef SAM_BELLIVEAU_MATH_EVAL_HEADER_HPP
#define SAM_BELLIVEAU_MATH_EVAL_HEADER_HPP 1
#include <functional> // std::function <- holding operators
#include <algorithm> // std::max, std::min
#include <exception> // Exceptions to throw
#include <stdexcept> // std::runtime_error
#include <cstdlib> // std::strtold <- parsing floats
#include <cstring> // std::strcmp <- comparing identifiers
#include <sstream> // std::stringstream <- error messages
#include <cstddef> // std::size_t
#include <cstdint> // Various ints
#include <utility> // std::move
#include <vector> // std::vector <- storing variables
#include <cctype> // std::isdigit
#include <stack> // std::stack <- stack based calculations
#include <cmath> // Math Equations
namespace eval
{
// Character type used for token identifiers and expression text.
using Letter = char;
// Non-owning pointer to a sequence of Letters; the code in this file reads
// such sequences up to a terminating '\0'.
using String = const Letter*;
// Unsigned type used for identifier lengths and indices.
using Size = std::size_t;
// Stack Type
// Container that backs every stack used by the evaluator.
template<class Type>
using StackBaseBackend = std::vector<Type>;
// std::stack adaptor layered over StackBaseBackend.
template<class Type>
using StackBase = std::stack<Type, StackBaseBackend<Type>>;
// Used for Conditionals
/// Interprets a number as a condition: only values strictly greater than
/// zero count as true; zero and negative values are false.
template<class Number>
static bool NumberToBool(const Number& n)
{
    const Number zero = static_cast<Number>(0);
    return zero < n;
}
/// Encodes a condition as a number: 1 for true, -1 for false.
template<class Number>
static Number BoolToNumber(bool cond)
{
    if (cond)
        return static_cast<Number>(1.0);
    return static_cast<Number>(-1.0);
}
// Evaluation State
template<class Number>
class State
{
public: // Token Classes
/*************************/
/*** OPERATOR ORDERING ***/
/*************************/
// Category reported by each token through getBlockLevel().
// Operator::isBlocking compares these values with `<`, and Bracket::add stops
// unwinding the token stack at the first token whose level is Bracket.
enum class BlockLevel : std::uint8_t
{
Base = 0x00, // default returned by Token (inherited by NoOP)
Function = 0x01, // returned by Unary/Binary/TrinaryFunction
Operator = 0x0f, // returned by Operator
Bracket = 0xfe, // returned by Bracket
Special = 0xff // returned by Variable
};
/************************/
/*** BASE INSTRUCTION ***/
/************************/
class Token {
private: // Data Members
static constexpr Size ID_SIZE = 32;
Letter id_[ID_SIZE];
Size id_size_;
public: // Overloads
// Used to tell if something should block
virtual BlockLevel getBlockLevel() const noexcept { return BlockLevel::Base; }
// Uses number stack to evaluate operation
virtual void evaluate(State& state) const {}
// Adds operator to state
virtual void add(State& state) const { state.t_stack.push(this); }
public: // Helper Functions
bool checkID(String test_ptr) const noexcept
{
for (Size i = 0; i < id_size_; ++i)
if(test_ptr[i] != id_[i])
return false;
return true;
}
Size getIDLength() const noexcept { return id_size_; }
String getID() const noexcept { return id_; }
public: // Constructor
Token(String id)
{
Size i = 0;
// Copy ID
for(; i < (ID_SIZE - 1) && id[i] != '\0'; ++i)
id_[i] = id[i];
// set ID size
id_size_ = i;
// zero out rest of string
for(; i < ID_SIZE; ++i) id_[i] = '\0';
}
};
/**********************/
/*** UNARY FUNCTION ***/
/**********************/
class UnaryFunction : public Token
{
public:
using FunctionT = std::function<Number(const Number&)>;
private:
FunctionT func_;
public:
virtual BlockLevel getBlockLevel() const noexcept { return BlockLevel::Function; }
virtual void evaluate(State& state) const
{
if(state.n_stack.size() < 1)
{
std::stringstream error;
error << "Syntax Error! [";
error << this->getID() << " -> Expected 1 Number, Recieved 0]";
throw std::runtime_error(error.str());
}
const Number a = state.n_stack.top(); state.n_stack.pop();
state.n_stack.push(func_(a));
}
public:
UnaryFunction(String id, FunctionT func) : Token(id), func_{func} {}
};
/***********************/
/*** BINARY FUNCTION ***/
/***********************/
class BinaryFunction : public Token
{
public:
using FunctionT = std::function<Number(const Number&, const Number&)>;
private:
FunctionT func_;
public:
virtual BlockLevel getBlockLevel() const noexcept
{ return BlockLevel::Function; }
virtual void evaluate(State& state) const
{
if(state.n_stack.size() < 2)
{
std::stringstream error;
error << "Syntax Error! [";
error << this->getID() << " -> Expected 2 Numbers, Recieved ";
error << state.n_stack.size() << "]";
throw std::runtime_error(error.str());
}
const Number b = state.n_stack.top(); state.n_stack.pop();
const Number a = state.n_stack.top(); state.n_stack.pop();
state.n_stack.push(func_(a, b));
}
public:
BinaryFunction(String id, FunctionT func) : Token(id), func_{func} {}
};
/*********************/
/*** OPERATOR TYPE ***/
/*********************/
class Operator : public BinaryFunction
{
public:
using Precedence = std::uint32_t;
enum class Order { Left, Right };
private:
Precedence prec_;
Order order_;
public:
virtual BlockLevel getBlockLevel() const noexcept
{ return BlockLevel::Operator; }
bool isBlocking(const Token& other) const noexcept
{
if (other.getBlockLevel() == BlockLevel::Operator)
{
const Operator* other_op = static_cast<const Operator*>(&other);
if (prec_ < other_op->prec_) {
return true;
} else if (prec_ == other_op->prec_) {
return other_op->order_ == Order::Left;
} else {
return false;
}
}
return other.getBlockLevel() < getBlockLevel();
}
virtual void add(State& state) const
{
// Evaluate operators with higher precedence
while(!state.t_stack.empty() && isBlocking(*state.t_stack.top()))
{
state.t_stack.top()->evaluate(state);
state.t_stack.pop();
}
state.t_stack.push(this);
}
public:
Operator(String id, typename BinaryFunction::FunctionT func, Precedence prec, Order order = Order::Left)
: BinaryFunction(id, func), prec_{prec}, order_{order} {}
};
/************************/
/*** Trinary FUNCTION ***/
/************************/
class TrinaryFunction : public Token
{
public:
using FunctionT = std::function<Number(const Number&, const Number&, const Number&)>;
private:
FunctionT func_;
public:
virtual BlockLevel getBlockLevel() const noexcept
{ return BlockLevel::Function; }
virtual void evaluate(State& state) const
{
if(state.n_stack.size() < 3)
{
std::stringstream error;
error << "Syntax Error! [";
error << this->getID() << " -> Expected 3 Numbers, Recieved ";
error << state.n_stack.size() << "]";
throw std::runtime_error(error.str());
}
const Number c = state.n_stack.top(); state.n_stack.pop();
const Number b = state.n_stack.top(); state.n_stack.pop();
const Number a = state.n_stack.top(); state.n_stack.pop();
state.n_stack.push(func_(a, b, c));
}
public:
TrinaryFunction(String id, FunctionT func) : Token(id), func_{func} {}
};
/*********************/
/*** VARIABLE TYPE ***/
/*********************/
class Variable : public Token
{
private:
Number value_;
public:
virtual BlockLevel getBlockLevel() const noexcept { return BlockLevel::Special; }
virtual void evaluate(State& state) const { state.n_stack.push(value_); }
virtual void add(State& state) const { evaluate(state); }
public:
Variable(String id, const Number& value) : Token(id), value_{value} {}
void setValue(const Number& value) { value_ = value; }
};
/********************/
/*** BRACKET TYPE ***/
/********************/
class Bracket : public Token
{
public:
enum class Type { Left, Right };
private:
Type type_;
public:
virtual BlockLevel getBlockLevel() const noexcept
{ return BlockLevel::Bracket; }
virtual void add(State& state) const
{
if(type_ == Type::Left) { state.t_stack.push(this); }
else
{
while(!state.t_stack.empty())
{
if(state.t_stack.top()->getBlockLevel() == BlockLevel::Bracket)
{ state.t_stack.pop(); return; }
else
{
state.t_stack.top()->evaluate(state);
state.t_stack.pop();
}
}
}
}
public:
Bracket(String id, Type type) : Token(id), type_{type} {}
};
/*****************/
/*** NULL TYPE ***/
/*****************/
class NoOP : public Token
{
public:
// Do nothing when added to token stack
virtual void add(State& state) const {}
public:
NoOP(String id) : Token(id) {}
};
public: // Types
// Non-owning pointer to an immutable token.
using TokenPtr = const Token*;
// Stack of operands / intermediate results.
using NStack = StackBase<Number>;
// Stack of tokens whose evaluation has been deferred.
using TStack = StackBase<TokenPtr>;
public: // Static Data
// Built-in tokens; ParseToken scans this table before the user variables.
static const TokenPtr CoreTokens[];
private:
// User-defined variables, filled by setVariable().
std::vector<Variable> vars;
public:
// Both stacks are public because the nested token classes manipulate them
// directly from add()/evaluate().
NStack n_stack;
TStack t_stack;
private: // Helper Functions
// Returns true when `i` is a decimal digit, i.e. when the text at the cursor
// should be parsed as a numeric literal.
// The character is converted to unsigned char first: handing std::isdigit a
// negative value (any byte >= 0x80 where char is signed) is undefined
// behaviour.
static bool IsValidNumber(Letter i)
{
return std::isdigit(static_cast<unsigned char>(i)) != 0;
}
// Parses the numeric literal at `input` with std::strtold, advances `input`
// past the characters that were consumed and returns the value as a Number.
// The intermediate result is kept as long double so that no precision is
// lost before the conversion to Number.
static Number ParseNumber(String& input)
{
Letter* end = nullptr;
const long double output = std::strtold(input, &end);
input = end;
return static_cast<Number>(output);
}
// Recognises the token at the cursor `in`, advances `in` past it and returns
// the token.
// Every core token and every user variable whose ID prefixes the input is
// considered and the longest one wins, so a short name can never shadow a
// longer one ("log" vs "log10", variable "x" vs "xy"). On equal length the
// earlier candidate is kept, core tokens being examined before variables.
// Throws std::runtime_error naming the unrecognised text when nothing matches.
TokenPtr ParseToken(String& in)
{
TokenPtr best = nullptr;
const auto consider = [&in, &best](TokenPtr candidate)
{
// A zero-length match would not advance the cursor; never accept it.
if (candidate->getIDLength() == 0 || !candidate->checkID(in))
return;
if (best == nullptr || best->getIDLength() < candidate->getIDLength())
best = candidate;
};
// Check Core Tokens
for (const auto& token : CoreTokens)
consider(token);
// Check Variables
for (const auto& variable : vars)
consider(&variable);
if (best != nullptr)
{
in += best->getIDLength();
return best;
}
/*** ERROR HANDLING ***/
// Give Helpful Error
std::stringstream error; // Base Error
error << "Syntax Error! [Unknown Operator \"";
// Quote the input up to the end of the string or up to the next position
// where a core token starts. Expensive, but it is error handling.
for (Size offset = 0; in[offset] != '\0'; ++offset)
{
bool at_core_token = false;
for (const auto& token : CoreTokens)
if (token->checkID(in + offset))
{ at_core_token = true; break; }
if (at_core_token) break;
error << in[offset];
}
error << "\"]";
throw std::runtime_error(error.str());
}
// Push a number onto the number stack.
void push(const Number& in)
{
n_stack.push(in);
}
// Let a token add itself to this state; null tokens are ignored.
void push(TokenPtr in)
{
if (in == nullptr)
return;
in->add(*this);
}
// Consumes one item at the cursor: a numeric literal when the first
// character is a digit, otherwise a token. `in` is advanced past the item.
void parse(String& in)
{
if (!IsValidNumber(in[0]))
{
push(ParseToken(in));
return;
}
push(ParseNumber(in));
}
// Evaluates every token still waiting on the token stack and returns the
// single value that must remain on the number stack.
// Throws std::runtime_error when zero or more than one number is left.
Number getResult()
{
// Drain the deferred tokens, most recent first.
for (; !t_stack.empty(); t_stack.pop())
t_stack.top()->evaluate(*this);

const auto remaining = n_stack.size();
if (remaining < 1)
throw std::runtime_error("Invald Syntax! [No Numbers Given]");
if (remaining != 1)
throw std::runtime_error("Invald Syntax! [More Numbers than Operators]");
return n_stack.top();
}
// Empties both stacks by popping every element.
void clear()
{
for (; !n_stack.empty(); n_stack.pop()) {}
for (; !t_stack.empty(); t_stack.pop()) {}
}
public:
// Evaluates the expression in `expr` and returns its value.
// Both stacks are cleared first; parsing stops at the end of the string or
// at the first newline, whichever comes first.
Number eval(String expr)
{
clear();
while (!(*expr == '\0' || *expr == '\n'))
parse(expr);
return getResult();
}
// Creates the variable `name` with value `val`, or updates it when a variable
// with exactly that name already exists.
// Names are compared as stored by Token (i.e. after truncation to the
// maximum ID length), and the comparison is exact: setting "xy" no longer
// overwrites an existing "x".
// Throws std::runtime_error when `name` is null or empty, because an empty
// identifier can never be parsed.
void setVariable(String name, const Number& val)
{
if (name == nullptr || name[0] == '\0')
throw std::runtime_error("Invalid Variable! [Empty Name]");
// Build the candidate first so both sides of the comparison have been
// through the same truncation.
const Variable fresh(name, val);
for (auto& existing : vars)
{
if (std::strcmp(existing.getID(), fresh.getID()) == 0)
{
existing.setValue(val);
return;
}
}
vars.push_back(fresh);
}
public:
State(std::size_t alloc_size = 512) noexcept
: n_stack{StackBaseBackend<Number>(alloc_size)}
, t_stack{StackBaseBackend<TokenPtr>(alloc_size)} {}
};
// Table of built-in tokens.
// Order matters: keep every identifier ahead of any shorter identifier it
// starts with ("<=" before "<", "!=" before "!", "else" before "e",
// "log10"/"log2" before "log"), so that a first-match prefix scan can never
// pick the shorter one. The entries are allocated once and live for the
// whole program.
template<class Number>
const typename State<Number>::TokenPtr State<Number>::CoreTokens[] = {
    // Ignore Characters
    new NoOP(" "),
    new NoOP(","),
    // Parenthesis
    new Bracket("(", Bracket::Type::Left),
    new Bracket(")", Bracket::Type::Right),
    new Bracket("[", Bracket::Type::Left),
    new Bracket("]", Bracket::Type::Right),
    new Bracket("{", Bracket::Type::Left),
    new Bracket("}", Bracket::Type::Right),
    // Operators
    new Operator(
        "+", [](auto a, auto b)
        { return a + b; }, 2, Operator::Order::Left
    ),
    new Operator(
        "-", [](auto a, auto b)
        { return a - b; }, 2, Operator::Order::Left
    ),
    new Operator(
        "*", [](auto a, auto b)
        { return a * b; }, 3, Operator::Order::Left
    ),
    new Operator(
        "/", [](auto a, auto b)
        { return a / b; }, 3, Operator::Order::Left
    ),
    new Operator(
        "%", [](auto a, auto b)
        { return std::fmod(a, b); }, 3, Operator::Order::Left
    ),
    new Operator(
        "^", [](auto a, auto b)
        { return std::pow(a, b); }, 4, Operator::Order::Right
    ),
    // If Statements
    new TrinaryFunction(
        "if", [](auto cond, auto a, auto b)
        { return NumberToBool(cond) ? a : b; }
    ), new NoOP("else"),
    // Constants
    new Variable("pi", 3.14159265358979323846264338327950288419716939937510L),
    new Variable("e", 2.71828182845904523536028747135266249775724709369995L),
    new Variable("phi", 1.61803398874989484820458683436563811772030917980576L),
    new Variable("true", 1.0L),
    new Variable("false", -1.0L),
    // Operators
    new Operator(
        "and", [](auto a, auto b)
        { return BoolToNumber<Number>(NumberToBool(a) && NumberToBool(b)); }, 0, Operator::Order::Left
    ),
    new Operator(
        "or", [](auto a, auto b)
        { return BoolToNumber<Number>(NumberToBool(a) || NumberToBool(b)); }, 0, Operator::Order::Left
    ),
    new Operator(
        "xor", [](auto a, auto b)
        { return BoolToNumber<Number>(NumberToBool(a) != NumberToBool(b)); }, 0, Operator::Order::Left
    ),
    new UnaryFunction(
        "not", [](auto a)
        { return BoolToNumber<Number>(!NumberToBool(a)); }
    ),
    new Operator(
        "==", [](auto a, auto b)
        { return BoolToNumber<Number>(a == b); }, 1, Operator::Order::Left
    ),
    new Operator(
        "!=", [](auto a, auto b)
        { return BoolToNumber<Number>(a != b); }, 1, Operator::Order::Left
    ),
    new Operator(
        "<=", [](auto a, auto b)
        { return BoolToNumber<Number>(a <= b); }, 1, Operator::Order::Left
    ),
    new Operator(
        ">=", [](auto a, auto b)
        { return BoolToNumber<Number>(a >= b); }, 1, Operator::Order::Left
    ),
    new Operator(
        "<", [](auto a, auto b)
        { return BoolToNumber<Number>(a < b); }, 1, Operator::Order::Left
    ),
    new Operator(
        ">", [](auto a, auto b)
        { return BoolToNumber<Number>(a > b); }, 1, Operator::Order::Left
    ),
    // Functions
    new UnaryFunction(
        "!", [](auto in)
        { return std::tgamma(in + 1.0); }
    ),
    new UnaryFunction(
        "abs", [](auto in)
        { return std::abs(in); }
    ),
    new UnaryFunction(
        "neg", [](auto in)
        { return -in; }
    ),
    new UnaryFunction(
        "sqrt", [](auto in)
        { return std::sqrt(in); }
    ),
    new UnaryFunction(
        "cbrt", [](auto in)
        { return std::cbrt(in); }
    ),
    new UnaryFunction(
        "sin", [](auto in)
        { return std::sin(in); }
    ),
    new UnaryFunction(
        "cos", [](auto in)
        { return std::cos(in); }
    ),
    new UnaryFunction(
        "tan", [](auto in)
        { return std::tan(in); }
    ),
    new UnaryFunction(
        "ln", [](auto in)
        { return std::log(in); }
    ),
    // "log10" and "log2" must precede "log": all three share the prefix
    // "log", and listing the short name first made the other two unreachable.
    new UnaryFunction(
        "log10", [](auto in)
        { return std::log10(in); }
    ),
    new UnaryFunction(
        "log2", [](auto in)
        { return std::log2(in); }
    ),
    new UnaryFunction(
        "log", [](auto in)
        { return std::log(in); }
    ),
    new UnaryFunction(
        "signum", [](auto in)
        { return (0 < in) ? 1.0 : -1.0; }
    ),
    new UnaryFunction(
        "bool", [](auto in)
        { return BoolToNumber<Number>(NumberToBool(in)); }
    ),
    new UnaryFunction(
        "round", [](auto in)
        { return std::round(in); }
    ),
    new UnaryFunction(
        "int", [](auto in)
        { return std::floor(in); }
    ),
    new UnaryFunction(
        "floor", [](auto in)
        { return std::floor(in); }
    ),
    new UnaryFunction(
        "ceil", [](auto in)
        { return std::ceil(in); }
    ),
    // Binary Functions
    new BinaryFunction(
        "max", [](auto a, auto b)
        { return std::max(a, b); }
    ),
    new BinaryFunction(
        "min", [](auto a, auto b)
        { return std::min(a, b); }
    ),
    new BinaryFunction(
        "add", [](auto a, auto b)
        { return a + b; }
    ),
    new BinaryFunction(
        "sub", [](auto a, auto b)
        { return a - b; }
    ),
    new BinaryFunction(
        "mul", [](auto a, auto b)
        { return a * b; }
    ),
    new BinaryFunction(
        "div", [](auto a, auto b)
        { return a / b; }
    ),
    new BinaryFunction(
        "pow", [](auto a, auto b)
        { return std::pow(a, b); }
    ),
    new BinaryFunction(
        "mod", [](auto a, auto b)
        { return std::fmod(a, b); }
    )
};
};
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment