Skip to content

Instantly share code, notes, and snippets.

@thebusytypist thebusytypist/ast.h
Last active Aug 2, 2016

Embed
What would you like to do?
Pattern Matching in C++
#ifndef AST_H
#define AST_H
#include <string>
#include <memory>
#include <vector>
#include <ostream>
// Abstract base class for every AST node.
// Child nodes are owned by this node through `data`; a node is printed through
// the pure-virtual print().
class Expr {
public:
    Expr() = default;

    // Nodes are owned and destroyed through std::unique_ptr<Expr> (see `data`
    // and the owning pointers built in main), so the destructor must be
    // virtual: deleting e.g. an AddExpr through an Expr* is otherwise
    // undefined behaviour.
    virtual ~Expr() = default;

    // `data` holds unique_ptrs, so nodes cannot be copied. Moves are kept
    // explicitly because declaring the destructor would otherwise suppress
    // the implicit move operations.
    Expr(const Expr&) = delete;
    Expr& operator=(const Expr&) = delete;
    Expr(Expr&&) = default;
    Expr& operator=(Expr&&) = default;

    // Returns a non-owning pointer to the i-th child.
    // Returns nullptr when `i` is negative or >= data.size(). Pattern code can
    // therefore ask any node for a child and simply fail to match, instead
    // of reading past the end of `data`.
    Expr* inData(int i) const {
        if (i < 0) {
            return nullptr;
        }
        const auto index = static_cast<decltype(data)::size_type>(i);
        return index < data.size() ? data[index].get() : nullptr;
    }

    // Writes a textual rendering of this node (and its children) to `os`.
    virtual void print(std::ostream& os) const = 0;

    // Owned child nodes, in positional order.
    std::vector<std::unique_ptr<Expr>> data;
};
// Named leaf node; print() writes the stored name unchanged.
class Symbol : public Expr {
public:
    std::string name;

    explicit Symbol(const std::string& n) : name(n) {}

    void print(std::ostream& os) const override {
        const std::string& label = name;
        os << label;
    }
};
// A named variable: a Symbol subtype that reuses Symbol's storage and print().
// It deliberately declares no `name` member of its own. Such a member would
// shadow Symbol::name, leaving `variable.name` empty while print() (inherited
// from Symbol) shows the name passed to the constructor.
class Variable : public Symbol {
public:
    explicit Variable(const std::string& n) : Symbol(n) {}
};
// Integer literal leaf node; print() streams the stored value with operator<<.
class Constant : public Expr {
public:
    int value;

    explicit Constant(int x) : value{x} {}

    void print(std::ostream& os) const override {
        const int literal = value;
        os << literal;
    }
};
// Binary addition node.
// Takes ownership of its operands: data[0] is the left operand and data[1] the
// right. print() renders "lhs+rhs" and dereferences both operands, so both
// must be non-null.
class AddExpr : public Expr {
public:
    AddExpr(std::unique_ptr<Expr> a, std::unique_ptr<Expr> b) {
        data.emplace_back(std::move(a));
        data.emplace_back(std::move(b));
    }

    void print(std::ostream& os) const override {
        const auto& lhs = data[0];
        lhs->print(os);
        os << "+";
        const auto& rhs = data[1];
        rhs->print(os);
    }
};
// Binary subtraction node.
// Takes ownership of its operands: data[0] is the left operand and data[1] the
// right. print() renders "lhs-rhs" and dereferences both operands, so both
// must be non-null.
class SubExpr : public Expr {
public:
    SubExpr(std::unique_ptr<Expr> a, std::unique_ptr<Expr> b) {
        data.emplace_back(std::move(a));
        data.emplace_back(std::move(b));
    }

    void print(std::ostream& os) const override {
        const auto& lhs = data[0];
        lhs->print(os);
        os << "-";
        const auto& rhs = data[1];
        rhs->print(os);
    }
};
inline void print(std::ostream& os, const Expr* expr) {
expr->print(os);
}
#endif
#ifndef ASTFWD_H
#define ASTFWD_H
// Forward declarations of the AST node classes defined in ast.h. Headers such
// as pattern.h include this file so they can name the node types in template
// arguments and specializations without including ast.h themselves.
class Expr;
class Symbol;
class Variable;
class Constant;
class AddExpr;
class SubExpr;
#endif
#include "ast.h"
#include "pattern.h"
#include <string>
#include <tuple>
#include <iostream>
using namespace std;
int main() {
unique_ptr<Expr> expr =
make_unique<AddExpr>(
make_unique<Symbol>("first"),
make_unique<SubExpr>(
make_unique<Symbol>("second"),
make_unique<Constant>(7)
));
auto data = make_tuple<int, string>(-1, "");
typedef cons2<AddExpr,
Expr,
cons2<SubExpr,
out<Symbol, 1>,
out<Constant, 0>>> PD;
cout << "expected: 1; actual: "
<< Pattern<PD>::match(expr.get(), data) << "\n";
cout << get<0>(data) << " " << get<1>(data) << "\n";
typedef cons2<AddExpr, Expr, Expr> P0;
cout << "expected: 1; actual: "
<< Pattern<P0>::match(expr.get(), data) << "\n";
typedef cons2<SubExpr, Expr, Expr> P1;
cout << "expected: 0; actual: "
<< Pattern<P1>::match(expr.get(), data) << "\n";
typedef cons2<AddExpr,
Expr,
cons2<SubExpr,
Expr,
Expr>> P2;
cout << "expected: 1; actual: "
<< Pattern<P2>::match(expr.get(), data) << "\n";
return 0;
}
#ifndef PATTERN_H
#define PATTERN_H
#include <tuple>    // std::get on the capture tuple in the Dispatcher specializations
#include <utility>

#include "astfwd.h"
// Constructors
// Pattern "constructor" tags: cons2<Root, T0, T1> / cons1<Root, T0> describe a
// node that matches Root and whose children match T0 (and T1). They are only
// used as template arguments and are not defined in this header.
template <typename Root, typename T0, typename T1> class cons2;
template <typename Root, typename T0> class cons1;
// Guards
// out<N, i>: match a node of type N and hand it to Dispatcher<out<N, i>>, which
// stores a field of the node into slot i of the caller's tuple.
template <typename N, int i> class out;
// Dispatchers
// Dispatcher<G> performs the action of guard G; specialized per guard below.
template <typename G> class Dispatcher;
// Descender
// Descend from S to T.
template <typename S, typename T> class Descender;
// General descending
template <typename S, typename T>
class Descender {
public:
static const T* descend(const S* n) {
return dynamic_cast<const T*>(n);
}
};
// Pattern
// Pattern<P>::match(n, t) tests node `n` against pattern P, writing any
// captures into tuple `t`. The primary (leaf) definition appears further below.
template <typename P> class Pattern;
// Destruct composition patterns.
// cons2<Root, T0, T1> matches a node that matches Root, has at least two
// children, and whose children 0 and 1 match T0 and T1 respectively.
// Evaluation is left to right and short-circuits. Captures written by
// sub-patterns that succeeded before a later failure stay in `t`.
template <typename Root, typename T0, typename T1>
class Pattern <cons2<Root, T0, T1>> {
public:
    template <typename N, typename Tuple>
    static bool match(const N* n, Tuple& t) {
        // Root is checked first. With the leaf Pattern, a null `n` already
        // fails here. The arity check prevents asking a node with fewer than
        // two children for child 0/1, which would index past the end of
        // `data`.
        return Pattern<Root>::match(n, t) &&
               n->data.size() >= 2 &&
               Pattern<T0>::match(n->inData(0), t) &&
               Pattern<T1>::match(n->inData(1), t);
    }
};
// cons1<Root, T0> matches a node that matches Root, has at least one child,
// and whose child 0 matches T0. Evaluation short-circuits left to right.
template <typename Root, typename T0>
class Pattern <cons1<Root, T0>> {
public:
    template <typename N, typename Tuple>
    static bool match(const N* n, Tuple& t) {
        // Reject childless nodes instead of indexing an empty `data`.
        return Pattern<Root>::match(n, t) &&
               !n->data.empty() &&
               Pattern<T0>::match(n->inData(0), t);
    }
};
// Destruct guards.
// G<Root, i> (e.g. out<Symbol, 1>) matches when Root matches. The node is then
// converted to const Root* via Descender and handed to Dispatcher<G<Root, i>>,
// whose return value decides the overall result.
template <template <typename, int> class G, typename Root, int i>
class Pattern <G<Root, i>> {
public:
    template <typename N, typename Tuple>
    static bool match(const N* n, Tuple& t) {
        if (!Pattern<Root>::match(n, t)) {
            return false;
        }
        // Distinct name from the match result, so the flag is not shadowed.
        // The null check keeps a looser Pattern<Root> specialization from
        // ever passing nullptr to a dispatcher that dereferences it.
        const Root* node = Descender<N, Root>::descend(n);
        if (node == nullptr) {
            return false;
        }
        return Dispatcher<G<Root, i>>::dispatch(node, t);
    }
};
// Match the leaf node.
// Fallback for any P that is not a cons1/cons2/guard pattern: succeeds iff
// Descender<N, P> can convert `n` to a const P* (a dynamic_cast in the general
// case). Leaf patterns capture nothing, so the tuple is left untouched.
template <typename P>
class Pattern {
public:
    template <typename N, typename Tuple>
    static bool match(const N* n, Tuple& /*t*/) {
        const auto target = Descender<N, P>::descend(n);
        return target != nullptr;
    }
};
// Specialized dispatchers
// out<Constant, i>: assigns the constant's `value` to slot i of the capture
// tuple. Always reports success.
// NOTE(review): `n` has the non-dependent type const Constant*, so Constant
// likely must be complete here. This header only includes astfwd.h, so
// ast.h presumably has to be included first (as the demo's main.cpp does).
template <int i>
class Dispatcher <out<Constant, i>> {
public:
    template <typename Tuple>
    static bool dispatch(const Constant* n, Tuple& t) {
        auto& slot = std::get<i>(t);
        slot = n->value;
        return true;
    }
};
// out<Symbol, i>: assigns the symbol's `name` to slot i of the capture tuple.
// Always reports success.
// NOTE(review): as with the Constant dispatcher, Symbol likely must be a
// complete type here, so ast.h presumably has to be included before this
// header.
template <int i>
class Dispatcher <out<Symbol, i>> {
public:
    template <typename Tuple>
    static bool dispatch(const Symbol* n, Tuple& t) {
        const auto& captured = n->name;
        std::get<i>(t) = captured;
        return true;
    }
};
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.