Skip to content

Instantly share code, notes, and snippets.

@anlun
Last active October 25, 2022 06:17
Show Gist options
  • Save anlun/a328a8194ccfb75fb59c0c12403e56bd to your computer and use it in GitHub Desktop.
Save anlun/a328a8194ccfb75fb59c0c12403e56bd to your computer and use it in GitHub Desktop.
Attempt to encode relation-algebra rewrite rules in egg
use egg::*;
// Term language for relation-algebra expressions, written as s-expressions,
// e.g. `(;; (* a) (? a))`.
define_language! {
    enum RelLanguage {
        // Nullary constants. The rules in `make_rules` use `top` as the
        // identity of `;;` and `bot` as its zero.
        "top" = Top,
        "bot" = Bot,
        // `(;; a b)`: sequential composition of two relations.
        ";;" = Seq([Id; 2]),
        // `(+ a)`: transitive closure (the rules relate it to `*` via a+ = a* ;; a).
        "+" = CT(Id),
        // `(? a)`: presumably reflexive closure (a ∪ id) — matches the rule
        // a* = a* ;; a?, but confirm the intended semantics.
        "?" = RT(Id),
        // `(* a)`: reflexive-transitive closure.
        "*" = CRT(Id),
        // Any other atom, e.g. a relation name such as `a`.
        Symbol(Symbol),
    }
}
/// Builds the rewrite rules used to simplify `RelLanguage` terms.
///
/// The rules are equational laws of relation composition (`;;`) and the
/// closure operators `+`, `?` and `*`. Every rule must hold for *all*
/// relations: egg merges both sides of a rewrite into one e-class, so a law
/// that is only an inclusion (`lhs ⊆ rhs`) would let extraction return a
/// term that is not equivalent to the input.
fn make_rules() -> Vec<Rewrite<RelLanguage, ()>> {
    // Bidirectional laws. `rewrite!(… <=> …)` yields a `Vec` holding the rule
    // and its `-rev` counterpart, hence the `concat`.
    let mut rules = vec![
        // Associativity of `;;`. Both directions are needed because `;;` is
        // not commutative: with only one direction a left-nested chain such
        // as `(x ;; a?) ;; a*` can never be re-bracketed to expose `a? ;; a*`.
        rewrite!("seqA"    ; "(;; ?a (;; ?b ?c))" <=> "(;; (;; ?a ?b) ?c)"),
        // a+ = a* ;; a = a ;; a*
        rewrite!("ct_end"  ; "(+ ?a)" <=> "(;; (* ?a) ?a)"),
        rewrite!("ct_begin"; "(+ ?a)" <=> "(;; ?a (* ?a))"),
        // a* = a* ;; a? = a? ;; a*
        rewrite!("rt_end"  ; "(* ?a)" <=> "(;; (* ?a) (? ?a))"),
        rewrite!("rt_begin"; "(* ?a)" <=> "(;; (? ?a) (* ?a))"),
    ]
    .concat();
    // One-directional simplifications (right-hand side is never larger).
    rules.extend(vec![
        // These treat `top` as the identity and `bot` as the zero of `;;`.
        // NOTE(review): if `top` is meant to be the universal relation rather
        // than the identity, `seq_id_l`/`seq_id_r` are unsound.
        rewrite!("seq_id_l"   ; "(;; top ?a)" => "?a"),
        rewrite!("seq_id_r"   ; "(;; ?a top)" => "?a"),
        rewrite!("seq_false_l"; "(;; bot ?a)" => "bot"),
        rewrite!("seq_false_r"; "(;; ?a bot)" => "bot"),
        // a* ;; a* = a*. This replaces the former rule `a+ ;; a+ => a+`,
        // which is unsound: for a = {(0,1)}, a+ ;; a+ is empty but a+ is not.
        rewrite!("seq_rt"     ; "(;; (* ?a) (* ?a))" => "(* ?a)"),
        // a+ ;; a* = a+ and a* ;; a+ = a+
        rewrite!("ct_rt"      ; "(;; (+ ?a) (* ?a))" => "(+ ?a)"),
        rewrite!("rt_ct"      ; "(;; (* ?a) (+ ?a))" => "(+ ?a)"),
    ]);
    rules
}
/// Parses `s` as a `RelLanguage` term, applies the rules from `make_rules`
/// with an egg `Runner`, and returns the smallest equivalent term found
/// (measured by `AstSize`) rendered as a string.
///
/// The input, the extracted term and its cost are also printed to stdout.
///
/// # Panics
/// Panics if `s` is not a well-formed `RelLanguage` s-expression.
fn simplify(s: &str) -> String {
    let input = s.parse::<RecExpr<RelLanguage>>().unwrap();
    let rules = make_rules();
    // Seed an e-graph with `input` and apply the rewrites until the runner stops.
    let runner = Runner::default().with_expr(&input).run(&rules);
    // `with_expr` registered a single root: the e-class containing `input`.
    let root_class = runner.roots[0];
    let (cost, term) = Extractor::new(&runner.egraph, AstSize).find_best(root_class);
    println!("Simplified {} to {} with cost {}", input, term, cost);
    term.to_string()
}
/// Smoke test: `a* ;; (a? ;; a?)` must simplify to `a*`; prints the wall-clock time taken.
fn main() {
    let started = std::time::Instant::now();
    let simplified = simplify("(;; (* a) (;; (? a) (? a)))");
    assert_eq!(simplified, "(* a)");
    println!("Elapsed: {:.2?}", started.elapsed());
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment