Skip to content

Instantly share code, notes, and snippets.

@deque-blog
Created March 30, 2017 16:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save deque-blog/d683c26256d9724fc9edadab45c2cc08 to your computer and use it in GitHub Desktop.
Save deque-blog/d683c26256d9724fc9edadab45c2cc08 to your computer and use it in GitHub Desktop.
#include <algorithm>
#include <cassert>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <boost/algorithm/string/join.hpp>
#include <boost/range/algorithm.hpp>
#include <boost/range/adaptors.hpp>
#include <boost/range/numeric.hpp>
#include <boost/variant.hpp>

using namespace boost::adaptors;
//--------------------------------------------------------
// Recursive operation
//--------------------------------------------------------
// Value type carried by literal (constant) nodes.
using nb = int;
// Name of a variable node.
using id = std::string;

// Empty tag types that tell apart the two n-ary operators built on the `op` template.
struct add_tag
{
};
struct mul_tag
{
};
// Generic n-ary operator node. `Tag` distinguishes the operator kind (e.g. add_tag,
// mul_tag) and `R` is the type of each operand.
template<typename Tag, typename R>
struct op
{
private:
  // Operands ("rands") of the operator, in the order they were supplied.
  std::vector<R> m_rands;

public:
  op() = default;

  // Copies every element of `rng` (anything exposing begin()/end()) as an operand.
  template<typename Range>
  explicit op(Range const& rng)
    : m_rands(rng.begin(), rng.end())
  {
  }

  // Read-only access to the operands.
  std::vector<R> const& rands() const
  {
    return m_rands;
  }
};
// Operator nodes, parameterised by the type R of their operands.
template<typename R>
using add_op = op<add_tag, R>;
template<typename R>
using mul_op = op<mul_tag, R>;

// One layer of the expression tree: a literal, a variable, or an n-ary sum/product
// whose operands have type R (R = expression yields the full recursive tree).
template<typename R>
using expression_r = boost::variant<nb, id, add_op<R>, mul_op<R>>;
// Recursive expression type: the fixed point of expression_r. Deriving from
// recursive_wrapper lets the variant contain (indirectly, on the heap) itself.
struct expression : boost::recursive_wrapper<expression_r<expression>>
{
  using base_type = boost::recursive_wrapper<expression_r<expression>>;
  // Inherit recursive_wrapper's constructors (from a layer, by copy or move).
  using base_type::base_type;
};
//--------------------------------------------------------
// Smart constructors
//--------------------------------------------------------
// Builds a literal node holding `i`.
expression cst(int i) { return expression(i); }

// Builds a variable-reference node. The parameter is named `name` rather than `id`
// so it does not hide the `id` type alias inside the body. It is taken by value
// and moved into the node, which saves one string copy.
expression var(id name) { return expression(std::move(name)); }

// Builds an n-ary sum whose operands are `rands` (copied into the node).
expression add(std::vector<expression> const& rands)
{
  return expression(add_op<expression>{ rands });
}

// Builds an n-ary product whose operands are `rands` (copied into the node).
expression mul(std::vector<expression> const& rands)
{
  return expression(mul_op<expression>{ rands });
}
//--------------------------------------------------------
// Query
//--------------------------------------------------------
// Returns a pointer to the literal stored in `e`, or nullptr when `e` is not a literal.
template <typename T>
int const* get_as_cst(expression_r<T> const& e)
{
  int const* const found = boost::get<int>(&e);
  return found;
}
// Returns a pointer to the variable name stored in `e`, or nullptr when `e` is not a variable.
template <typename T>
id const* get_as_var(expression_r<T> const& e)
{
  id const* const found = boost::get<id>(&e);
  return found;
}
// Returns a pointer to the sum node stored in `e`, or nullptr when `e` is not a sum.
template <typename T>
add_op<T> const* get_as_add(expression_r<T> const& e)
{
  add_op<T> const* const found = boost::get<add_op<T>>(&e);
  return found;
}
// Returns a pointer to the product node stored in `e`, or nullptr when `e` is not a product.
template <typename T>
mul_op<T> const* get_as_mul(expression_r<T> const& e)
{
  mul_op<T> const* const found = boost::get<mul_op<T>>(&e);
  return found;
}
// Signals that an if-chain over the alternatives of expression_r matched none of them.
// Marked [[noreturn]]: this function always throws. Non-void callers that end with
// this call therefore do not "fall off the end" as far as the compiler is concerned,
// which avoids -Wreturn-type warnings.
// Throws: std::logic_error("Missing case in pattern matching").
[[noreturn]] void throw_missing_pattern_matching_clause()
{
  throw std::logic_error("Missing case in pattern matching");
}
//--------------------------------------------------------
// FUNCTOR INSTANCE
//--------------------------------------------------------
// Functor map over one layer of the tree.
// Applies `map` to every operand of a sum or product and rebuilds the node as
// expression_r<B>, where B is the result type of `map`. Literals and variables
// have no operands and are copied through unchanged.
template<typename A, typename M>
auto fmap(M map, expression_r<A> const& e)
{
  using B = decltype(map(std::declval<A>()));
  using Out = expression_r<B>;

  if (auto const* n = get_as_cst(e))
    return Out(*n);
  if (auto const* name = get_as_var(e))
    return Out(*name);
  if (auto const* sum = get_as_add(e))
  {
    auto mapped = sum->rands() | transformed(map);
    return Out(add_op<B>(mapped));
  }
  if (auto const* prod = get_as_mul(e))
  {
    auto mapped = prod->rands() | transformed(map);
    return Out(mul_op<B>(mapped));
  }
  throw_missing_pattern_matching_clause();
}
//--------------------------------------------------------
// CATAMORPHISM
//--------------------------------------------------------
// Catamorphism: bottom-up fold of an expression tree.
// Every child is folded recursively first. The resulting layer, an expression_r<Out>,
// is then handed to the algebra `f`, which combines it into a single Out.
template<typename Out, typename Algebra>
Out cata(Algebra f, expression const& ast)
{
  auto fold_child = [f](expression const& child) -> Out {
    return cata<Out>(f, child);
  };
  return f(fmap(fold_child, ast.get()));
}
//--------------------------------------------------------
// PARAMORPHISM
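// The algebra f now takes an expression_r of (Out, expression const*) pairs:
// each child carries both its recursive result and a pointer to the original
// sub-expression, which gives the algebra access to the context of the evaluation.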
//--------------------------------------------------------
// Paramorphism: like cata, but each child reaches the algebra as a pair made of
// (its folded result, pointer to the child sub-expression itself). The pointer refers
// into `ast` and is only valid while `f` runs.
template<typename Out, typename Algebra>
Out para(Algebra f, expression const& ast)
{
  auto annotate = [f](expression const& child) -> std::pair<Out, expression const*> {
    return std::pair<Out, expression const*>(para<Out>(f, child), &child);
  };
  return f(fmap(annotate, ast.get()));
}
//--------------------------------------------------------
// DISPLAY
//--------------------------------------------------------
// Renders an operator node in prefix form: "(" + op_repr + " " + operands joined by
// single spaces + ")". Operands are already rendered strings.
template<typename Tag>
std::string print_op(op<Tag, std::string> const& e, std::string const& op_repr)
{
  std::string out = "(";
  out += op_repr;
  out += ' ';
  out += boost::algorithm::join(e.rands(), " ");
  out += ')';
  return out;
}
// Catamorphism algebra producing a prefix (Lisp-like) rendering of one layer whose
// children have already been rendered.
std::string print_alg(expression_r<std::string> const& e)
{
  if (auto const* n = get_as_cst(e)) return std::to_string(*n);
  if (auto const* name = get_as_var(e)) return *name;
  if (auto const* sum = get_as_add(e)) return print_op(*sum, "+");
  if (auto const* prod = get_as_mul(e)) return print_op(*prod, "*");
  throw_missing_pattern_matching_clause();
}
//--------------------------------------------------------
// DISPLAY (INFIX)
//--------------------------------------------------------
// Naive infix rendering of a sum: rendered operands joined by " + ", with no parentheses.
std::string print_infix_op_bad(op<add_tag, std::string> const& e)
{
  auto const& terms = e.rands();
  return boost::algorithm::join(terms, " + ");
}
// Returns `s` wrapped in one pair of parentheses.
std::string with_parens(std::string const& s)
{
  std::string wrapped;
  wrapped.reserve(s.size() + 2);
  wrapped += '(';
  wrapped += s;
  wrapped += ')';
  return wrapped;
}
// Naive infix rendering of a product: every factor is parenthesised, whether or not
// precedence requires it, and the factors are joined by " * ".
std::string print_infix_op_bad(op<mul_tag, std::string> const& e)
{
  auto const parenthesised = e.rands() | transformed(with_parens);
  return boost::algorithm::join(parenthesised, " * ");
}
// Catamorphism algebra for the naive infix rendering. It only sees rendered strings,
// so it cannot tell which factors actually need parentheses.
std::string print_infix_bad(expression_r<std::string> const& e)
{
  if (auto const* n = get_as_cst(e)) return std::to_string(*n);
  if (auto const* name = get_as_var(e)) return *name;
  if (auto const* sum = get_as_add(e)) return print_infix_op_bad(*sum);
  if (auto const* prod = get_as_mul(e)) return print_infix_op_bad(*prod);
  throw_missing_pattern_matching_clause();
}
//--------------------------------------------------------
// Infix rendering of a sum for the paramorphism. Only the rendered text of each term
// is used; terms are joined by " + ".
std::string print_op_infix(op<add_tag, std::pair<std::string, expression const*>> const& e)
{
  auto rendered = [](std::pair<std::string, expression const*> const& term) {
    return term.first;
  };
  return boost::algorithm::join(e.rands() | transformed(rendered), " + ");
}
// Infix rendering of a product for the paramorphism. A factor whose original
// sub-expression is a sum is parenthesised; all other factors are emitted as-is.
// Factors are joined by " * ".
std::string print_op_infix(op<mul_tag, std::pair<std::string, expression const*>> const& e)
{
  auto render_factor = [](std::pair<std::string, expression const*> const& factor) -> std::string {
    bool const is_sum = get_as_add(factor.second->get()) != nullptr;
    return is_sum ? with_parens(factor.first) : factor.first;
  };
  return boost::algorithm::join(e.rands() | transformed(render_factor), " * ");
}
// Paramorphism algebra for infix rendering. Each child arrives as (rendered text,
// pointer to its sub-expression), which lets products decide where parentheses are needed.
std::string print_infix(expression_r<std::pair<std::string, expression const*>> const& e)
{
  if (auto const* n = get_as_cst(e)) return std::to_string(*n);
  if (auto const* name = get_as_var(e)) return *name;
  if (auto const* sum = get_as_add(e)) return print_op_infix(*sum);
  if (auto const* prod = get_as_mul(e)) return print_op_infix(*prod);
  throw_missing_pattern_matching_clause();
}
//--------------------------------------------------------
// EVALUATION
//--------------------------------------------------------
// Variable bindings used during evaluation.
using env = std::map<id, nb>;

// Builds the evaluation algebra. It folds one layer of expression_r<int>, whose
// children are already evaluated, into an int:
// - sums add their operands, starting from 0;
// - products multiply their operands, starting from 1;
// - variables are looked up in `env`.
// The returned lambda holds a reference to `env`, so it must not outlive it.
// Throws std::out_of_range (a std::logic_error) if a variable has no binding in `env`.
auto eval_alg(env const& env)
{
  return [&env] (expression_r<int> const& e) -> int
  {
    if (auto* o = get_as_add(e))
      return boost::accumulate(o->rands(), 0, std::plus<int>());
    if (auto* o = get_as_mul(e))
      return boost::accumulate(o->rands(), 1, std::multiplies<int>());
    if (auto* v = get_as_var(e))
    {
      auto it = env.find(*v);
      // An unbound variable must be reported: dereferencing end() would be undefined behavior.
      if (it == env.end())
        throw std::out_of_range("Unbound variable: " + *v);
      return it->second;
    }
    if (auto* i = get_as_cst(e)) return *i;
    throw_missing_pattern_matching_clause();
  };
}
// Evaluates `expr` bottom-up, reading each variable's value from `env`.
int eval(env const& env, expression const& expr)
{
  auto const algebra = eval_alg(env);
  return cata<int>(algebra, expr);
}
//--------------------------------------------------------
// DEPENDENCIES
//--------------------------------------------------------
// Returns the union of the identifier sets held as the operands of `operation`.
// Each operand set is bound by const reference; copying every child set per
// iteration would be needless work. The parameter is not named `op`, so it does
// not shadow the `op` template.
template<typename Tag>
std::set<id> join_sets(op<Tag, std::set<id>> const& operation)
{
  std::set<id> out;
  for (auto const& sub_deps : operation.rands())
    out.insert(sub_deps.begin(), sub_deps.end());
  return out;
}
// Catamorphism algebra collecting variable names:
// - a variable contributes its own name;
// - sums and products contribute the union of their operands' sets;
// - anything else (literals) contributes nothing.
std::set<id> dependencies_alg(expression_r<std::set<id>> const& e)
{
  if (auto const* name = get_as_var(e)) return std::set<id>{*name};
  if (auto const* sum = get_as_add(e)) return join_sets(*sum);
  if (auto const* prod = get_as_mul(e)) return join_sets(*prod);
  return std::set<id>{};
}
// Returns the set of variable names that occur anywhere in `e`.
std::set<id> dependencies(expression const& e)
{
  using deps_t = std::set<id>;
  return cata<deps_t>(&dependencies_alg, e);
}
//--------------------------------------------------------
// OPTIMIZATIONS
//--------------------------------------------------------
// Folds all literal operands of an n-ary operator into a single constant using `step`,
// starting from `neutral`.
// Returns:
// - the constant alone, if no non-literal operands remain;
// - the sole remaining operand, if exactly one is left (the folded constant is
//   appended only when it differs from `neutral`);
// - otherwise, the operator over the non-literal operands, with the constant last.
template<typename Tag, typename Step>
expression optimize_op(op<Tag, expression> const& e, int neutral, Step step)
{
  int folded = neutral;
  std::vector<expression> rest;
  for (auto const& operand : e.rands())
  {
    auto const* constant = get_as_cst(operand.get());
    if (!constant)
    {
      rest.push_back(operand);
      continue;
    }
    folded = step(folded, *constant);
  }

  if (rest.empty())
    return cst(folded);
  if (folded != neutral)
    rest.push_back(cst(folded));
  return rest.size() == 1 ? rest.front() : expression(op<Tag, expression>(rest));
}
// True when at least one element of `subs` is the literal 0.
template<typename Range>
bool has_zero(Range const& subs)
{
  return std::any_of(std::begin(subs), std::end(subs), [](expression const& sub) {
    int const* value = get_as_cst(sub.get());
    return value != nullptr && *value == 0;
  });
}
// Simplifies a sum layer by folding its literal operands (neutral element 0).
// Any other layer is returned unchanged.
expression opt_add_alg(expression_r<expression> const& e)
{
  auto const* sum = get_as_add(e);
  if (sum == nullptr)
    return e;
  return optimize_op(*sum, 0, std::plus<int>());
}
// Simplifies a product layer:
// - a literal 0 operand collapses the whole product to 0;
// - otherwise the literal operands are folded (neutral element 1).
// Any other layer is returned unchanged.
expression opt_mul_alg(expression_r<expression> const& e)
{
  auto const* prod = get_as_mul(e);
  if (prod == nullptr)
    return e;
  return has_zero(prod->rands())
    ? cst(0)
    : optimize_op(*prod, 1, std::multiplies<int>());
}
// Applies the sum simplification, then the product simplification, to a single layer.
expression optimize_alg(expression_r<expression> const& e)
{
  expression const after_add = opt_add_alg(e);
  return opt_mul_alg(after_add.get());
}
//--------------------------------------------------------
// PARTIAL EVAL
//--------------------------------------------------------
// Builds an algebra that replaces each variable bound in `env` with a literal of its value.
// Unbound variables and non-variable layers are kept as they are. The returned lambda
// holds a reference to `env`, so it must not outlive it.
auto partial_eval_alg(env const& env)
{
  return [&env] (expression_r<expression> const& e) -> expression
  {
    auto const* name = get_as_var(e);
    if (!name)
      return e;
    auto const binding = env.find(*name);
    return binding == env.end() ? var(*name) : cst(binding->second);
  };
}
// Partially evaluates `e`: at every node, bottom-up, bound variables are substituted
// by their values from `env` and the resulting layer is then simplified by optimize_alg.
expression partial_eval(env const& env, expression const& e)
{
  auto const substitute = partial_eval_alg(env);
  auto const step = [&substitute](expression_r<expression> const& layer) -> expression {
    return optimize_alg(substitute(layer).get());
  };
  return cata<expression>(step, e);
}
//--------------------------------------------------------
// EVALUATION (Different implementations)
//--------------------------------------------------------
// Throws a std::logic_error whose message lists `dependencies` in set order, each
// name followed by a single space.
// Marked [[noreturn]] because it always throws, so callers such as eval_2 do not fall
// off the end of a non-void function after calling it.
[[noreturn]] void throw_missing_variables(std::set<id> const& dependencies)
{
  std::ostringstream s;
  for (auto const& d : dependencies)
    s << d << " ";
  throw std::logic_error(s.str());
}
// Evaluates `e` by partially evaluating it against `env` and expecting a literal back.
// If the reduced expression is not a literal, throw_missing_variables is called with
// the variables still present; it throws a std::logic_error listing them.
int eval_2(env const& env, expression const& e)
{
  expression const reduced = partial_eval(env, e);
  int const* const value = get_as_cst(reduced.get());
  if (value == nullptr)
    throw_missing_variables(dependencies(reduced));
  return *value;
}
//--------------------------------------------------------
// Tests
//--------------------------------------------------------
int main()
{
  // Sample expression: 1 + 2 + 0*x*y + 1*y*(2 + x) + (0 + x)
  expression e = add({
    cst(1),
    cst(2),
    mul({cst(0), var("x"), var("y")}),
    mul({cst(1), var("y"), add({cst(2), var("x")})}),
    add({cst(0), var("x")})
  });
  env full_env = {{"x", 1}, {"y", 2}};

  // Rendering: prefix form, naive infix (catamorphism), precedence-aware infix (paramorphism).
  std::cout << cata<std::string>(print_alg, e) << std::endl;
  std::cout << cata<std::string>(print_infix_bad, e) << std::endl;
  std::cout << para<std::string>(print_infix, e) << std::endl;

  // Full evaluation, both directly and through partial evaluation.
  std::cout << eval(full_env, e) << std::endl;
  std::cout << eval_2(full_env, e) << std::endl;

  // Substitution only (no simplification). Every variable is replaced by a literal,
  // so the result evaluates even in an empty environment.
  // partial_eval() is the variant that also chains in the simplification step.
  auto e2 = cata<expression>(partial_eval_alg(full_env), e);
  env empty_env;
  std::cout << cata<std::string>(print_alg, e2) << std::endl;
  std::cout << eval(empty_env, e2) << std::endl;
  std::cout << eval_2(empty_env, e2) << std::endl;

  // Simplification only (no substitution).
  auto e3 = cata<expression>(optimize_alg, e);
  std::cout << cata<std::string>(print_alg, e3) << std::endl;
  std::cout << eval(full_env, e3) << std::endl;
  std::cout << eval_2(full_env, e3) << std::endl;

  // With no bindings, eval_2 reports the unresolved variables via std::logic_error.
  try
  {
    eval_2(empty_env, e);
  }
  catch (std::logic_error const& err)
  {
    std::cout << err.what() << std::endl;
  }
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment