Skip to content

Instantly share code, notes, and snippets.

@wilcoxjay
Created July 14, 2020 02:00
Show Gist options
  • Save wilcoxjay/9df41a92651cc2162541ae7a61954d21 to your computer and use it in GitHub Desktop.
Save wilcoxjay/9df41a92651cc2162541ae7a61954d21 to your computer and use it in GitHub Desktop.
use egg::*;
// Shorthand for an e-graph over `FibLang` terms that uses `ConstantFold` as its analysis.
type EGraph = egg::EGraph<FibLang, ConstantFold>;
// Term language for the Fibonacci example: binary `+` and `-`, a unary `fib`
// application, unsigned 64-bit integer literals, and symbols.
define_language! {
enum FibLang {
// Operator variants: the quoted string is the operator name as it appears in the
// rewrite patterns of this file (e.g. "(fib 0)"); the array holds the operand `Id`s.
"+" = Add([Id; 2]),
"-" = Sub([Id; 2]),
"fib" = Fib([Id; 1]),
// Leaf variants: a `u64` literal and a symbol.
Int(u64),
Symbol(egg::Symbol),
}
}
/// Marker type for the constant-folding analysis. It has no fields; its behavior
/// comes entirely from the `Analysis<FibLang>` implementation in this file.
#[derive(Default)]
pub struct ConstantFold;
/// Constant folding over `FibLang`: each e-class carries `Some(c)` when it is
/// known to equal the integer constant `c`, and `None` otherwise.
impl Analysis<FibLang> for ConstantFold {
    type Data = Option<u64>;

    /// Combines the data of two e-classes that are being merged.
    ///
    /// If both sides already have a constant, the constants must agree. A known
    /// constant in `to` is kept; otherwise `from` is taken. The return value is
    /// whatever `merge_if_different` reports (presumably whether `to` changed).
    ///
    /// # Panics
    ///
    /// Panics if both classes carry constants and the constants differ.
    fn merge(&self, to: &mut Self::Data, from: Self::Data) -> bool {
        if let (Some(c1), Some(c2)) = (to.as_ref(), from.as_ref()) {
            assert_eq!(c1, c2);
        }
        merge_if_different(to, to.or(from))
    }

    /// Computes the constant value of `enode`, if it has one.
    ///
    /// Integer literals are constants. `+` and `-` fold when both operands are
    /// constants and the result fits in a `u64`. Every other node yields `None`.
    fn make(egraph: &EGraph, enode: &FibLang) -> Self::Data {
        let x = |i: &Id| egraph[*i].data;
        match enode {
            FibLang::Int(c) => Some(*c),
            // Checked arithmetic: an overflowing sum or a negative difference is
            // "not a known constant" rather than a panic (debug builds) or a
            // silently wrapped, wrong constant (release builds).
            FibLang::Add([a, b]) => x(a)?.checked_add(x(b)?),
            FibLang::Sub([a, b]) => x(a)?.checked_sub(x(b)?),
            _ => None,
        }
    }

    /// When the class `id` has a known constant, adds the matching `Int` node,
    /// unions it into the class, and then prunes the class down to its leaf nodes.
    ///
    /// # Panics
    ///
    /// Panics if pruning leaves the class without any node.
    fn modify(egraph: &mut EGraph, id: Id) {
        // `data` is copied out, so no borrow of the e-graph is held across the
        // mutations below.
        if let Some(c) = egraph[id].data {
            let added = egraph.add(FibLang::Int(c));
            let (id, _did_something) = egraph.union(id, added);
            // to not prune, comment this out
            egraph[id].nodes.retain(|n| n.is_leaf());
            assert!(
                !egraph[id].nodes.is_empty(),
                "empty eclass! {:#?}",
                egraph[id]
            );
            #[cfg(debug_assertions)]
            egraph[id].assert_unique_leaves();
        }
    }
}
/// Builds a rewrite condition that holds when the e-class bound to the pattern
/// variable `var` contains an integer literal strictly greater than one.
///
/// # Panics
///
/// Panics if `var` cannot be parsed as a pattern variable.
fn is_const_bigger_one(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
    let pattern_var = var.parse().unwrap();
    move |egraph, _, subst| {
        let class = &egraph[subst[pattern_var]];
        for node in class.nodes.iter() {
            if let FibLang::Int(value) = node {
                if *value > 1 {
                    return true;
                }
            }
        }
        false
    }
}
/// Returns the rewrite rules that define `fib`: the two base cases, and the
/// recurrence, which is guarded by `is_const_bigger_one` on `?n`.
fn rules() -> Vec<Rewrite<FibLang, ConstantFold>> {
    // Base cases.
    let base_zero = rewrite!("fib-defn0"; "(fib 0)" => "0");
    let base_one = rewrite!("fib-defn1"; "(fib 1)" => "1");
    // Recurrence, applied only when the condition on `?n` holds.
    let recurrence = rewrite!(
        "fib-defn2";
        "(fib ?n)" => "(+ (fib (- ?n 1)) (fib (- ?n 2)))"
        if is_const_bigger_one("?n")
    );
    vec![base_zero, base_one, recurrence]
}
// Test cases: each `test_fn!` invocation gives a test name, the rule set from
// `rules()`, and a start expression with the expression it should reach. The
// right-hand sides are the Fibonacci numbers F(3), F(10), F(20), F(30), F(40) and F(50).
// NOTE(review): presumably `test_fn!` runs the rules from the start expression and
// checks that the goal becomes equivalent to it; confirm against egg's documentation.
test_fn! {
fib3, rules(),
"(fib 3)" => "2"
}
test_fn! {
fib10, rules(),
"(fib 10)" => "55"
}
test_fn! {
fib20, rules(),
"(fib 20)" => "6765"
}
test_fn! {
fib30, rules(),
"(fib 30)" => "832040"
}
test_fn! {
fib40, rules(),
"(fib 40)" => "102334155"
}
test_fn! {
fib50, rules(),
"(fib 50)" => "12586269025"
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment