Skip to content

Instantly share code, notes, and snippets.

@philzook58
Created March 14, 2021 04:07
Show Gist options
  • Save philzook58/b6779386a3df4aa2e033b5c804ee7547 to your computer and use it in GitHub Desktop.
Save philzook58/b6779386a3df4aa2e033b5c804ee7547 to your computer and use it in GitHub Desktop.
Egg categories
use egg::{*, rewrite as rw};
use wasm_bindgen::prelude::*;
use std::time::Duration;
// Term language for category expressions, written as s-expressions such as
// `(. (id a) f)`. Each string is the operator name accepted by the parser; the
// length of the `Id` array on a variant is the number of children it takes.
// Judging from the typing rewrite rules in this file:
//   - "." is composition in diagrammatic order: `(. f g)` has type
//     hom(dom(type f), cod(type g)), i.e. f runs first, then g.
//   - "om" is the tensor of morphisms, "oo" the tensor of objects, and "munit"
//     the unit object for "oo".
//   - "swap" is the symmetry a⊗b -> b⊗a, "dup" is a -> a⊗a, "del" is a -> munit,
//     "pair"/"proj1"/"proj2" are the cartesian pairing and projections.
//   - "type" of a morphism is a "hom" of two objects; "type" of an object is "ob".
//   - "dom"/"cod" project the domain/codomain out of a "hom".
define_language! {
enum CatLanguage {
// operator string = Variant(children); e.g. "id" takes one child (an object)
"id" = IdMorph(Id),
"om" = OTimesM([Id; 2]),
"oo" = OTimesO([Id; 2]),
"." = Comp([Id; 2]),
"munit" = MUnit,
"swap" = Sigma([Id;2]),
"type" = Type(Id),
"dom" = Dom(Id),
"cod" = Cod(Id),
"hom" = Hom([Id;2]),
"ob" = Ob,
"dup" = Dup(Id),
"del" = Del(Id),
"pair" = Pair([Id;2]),
"proj1" = Proj1([Id;2]),
"proj2" = Proj2([Id;2]),
// Catch-all leaf for any name that is not one of the operators above
// (e.g. object names like `a` or morphism names like `f`).
Symbol(Symbol),
}
}
// egg's generic `Pattern` and `EGraph` specialised to `CatLanguage`, with `()`
// as the e-class analysis (no per-class data). These local aliases shadow the
// same names brought in by the `egg::*` glob import.
type Pattern = egg::Pattern<CatLanguage>;
type EGraph = egg::EGraph<CatLanguage, ()>;
/// Build the rewrite system used by [`simplify`].
///
/// Contains, in order: typing rules (`type`, `dom`, `cod`), identity/unit and
/// associativity laws, symmetric-monoidal (`swap`) axioms, cartesian
/// (`dup`/`del`/`pair`/`proj*`) axioms, and finally a few conditional rules whose
/// side conditions compare computed types of the matched morphisms.
fn cat_rules() -> Vec<Rewrite<CatLanguage, ()>> {
    // Unidirectional `rw!(… => …)` yields one rewrite, bidirectional `<=>` yields
    // a Vec of two, so every entry is a Vec and the whole list is flattened.
    let mut rules: Vec<Rewrite<CatLanguage, ()>> = vec![
        vec![rw!( "dom(hom(a, b)) => a" ; "(dom (hom ?a ?b))" => "?a" )],
        vec![rw!( "cod(hom(a, b)) => b" ; "(cod (hom ?a ?b))" => "?b" )],
        vec![rw!( "type(id(a)) => hom(a, a)" ; "(type (id ?a))" => "(hom ?a ?a)" )],
        vec![rw!( "type(f . g) => hom(dom(type(f)), cod(type(g)))" ; "(type (. ?f ?g))" => "(hom (dom (type ?f)) (cod (type ?g)))" )],
        vec![rw!( "type(f om g) => hom(dom(type(f)) oo dom(type(g)), cod(type(f)) oo cod(type(g)))" ; "(type (om ?f ?g))" => "(hom (oo (dom (type ?f)) (dom (type ?g))) (oo (cod (type ?f)) (cod (type ?g))))" )],
        vec![rw!( "type(a oo b) => :ob" ; "(type (oo ?a ?b))" => "ob" )],
        vec![rw!( "type(munit()) => :ob" ; "(type munit)" => "ob" )],
        vec![rw!( "type(swap(a, b)) => hom(a oo b, b oo a)" ; "(type (swap ?a ?b))" => "(hom (oo ?a ?b) (oo ?b ?a))" )],
        vec![rw!( "type((del)(a)) => hom(a, munit())" ; "(type (del ?a))" => "(hom ?a munit)" )],
        vec![rw!( "type(dup(a)) => hom(a, a oo a)" ; "(type (dup ?a))" => "(hom ?a (oo ?a ?a))" )],
        vec![rw!( "type(pair(f, g)) => hom(dom(type(f)), cod(type(f)) oo cod(type(g)))" ; "(type (pair ?f ?g))" => "(hom (dom (type ?f)) (oo (cod (type ?f)) (cod (type ?g))))" )],
        vec![rw!( "type(proj1(a, b)) => hom(a oo b, a)" ; "(type (proj1 ?a ?b))" => "(hom (oo ?a ?b) ?a)" )],
        vec![rw!( "type(proj2(a, b)) => hom(a oo b, b)" ; "(type (proj2 ?a ?b))" => "(hom (oo ?a ?b) ?b)" )],
        vec![rw!( "f . id(b) => f" ; "(. ?f (id ?b))" => "?f" )],
        vec![rw!( "id(a) . f => f" ; "(. (id ?a) ?f)" => "?f" )],
        vec![rw!( "a oo munit() => a" ; "(oo ?a munit)" => "?a" )],
        vec![rw!( "munit() oo a => a" ; "(oo munit ?a)" => "?a" )],
        rw!( "f . (g . h) == (f . g) . h" ; "(. ?f (. ?g ?h))" <=> "(. (. ?f ?g) ?h)" ),
        vec![rw!( "id(munit()) om f => f" ; "(om (id munit) ?f)" => "?f" )],
        vec![rw!( "f om id(munit()) => f" ; "(om ?f (id munit))" => "?f" )],
        rw!( "a oo (b oo c) == (a oo b) oo c" ; "(oo ?a (oo ?b ?c))" <=> "(oo (oo ?a ?b) ?c)" ),
        rw!( "f om (h om j) == (f om h) om j" ; "(om ?f (om ?h ?j))" <=> "(om (om ?f ?h) ?j)" ),
        rw!( "id(a oo b) == id(a) om id(b)" ; "(id (oo ?a ?b))" <=> "(om (id ?a) (id ?b))" ),
        vec![rw!( "(f . g) om (p . q) => (f om p) . (g om q)" ; "(om (. ?f ?g) (. ?p ?q))" => "(. (om ?f ?p) (om ?g ?q))" )],
        rw!( "swap(a, b) . swap(b, a) == id(a oo b)" ; "(. (swap ?a ?b) (swap ?b ?a))" <=> "(id (oo ?a ?b))" ),
        rw!( "(swap(a, b) om id(c)) . (id(b) om swap(a, c)) == swap(a, b oo c)" ; "(. (om (swap ?a ?b) (id ?c)) (om (id ?b) (swap ?a ?c)))" <=> "(swap ?a (oo ?b ?c))" ),
        rw!( "(id(a) om swap(b, c)) . (swap(a, c) om id(b)) == swap(a oo b, c)" ; "(. (om (id ?a) (swap ?b ?c)) (om (swap ?a ?c) (id ?b)))" <=> "(swap (oo ?a ?b) ?c)" ),
        rw!( "swap(a, munit()) == id(a)" ; "(swap ?a munit)" <=> "(id ?a)" ),
        rw!( "swap(munit(), a) == id(a)" ; "(swap munit ?a)" <=> "(id ?a)" ),
        vec![rw!( "swap(munit(), munit()) => id(munit() oo munit())" ; "(swap munit munit)" => "(id (oo munit munit))" )],
        rw!( "dup(a) . ((del)(a) om id(a)) == id(a)" ; "(. (dup ?a) (om (del ?a) (id ?a)))" <=> "(id ?a)" ),
        rw!( "dup(a) . (id(a) om (del)(a)) == id(a)" ; "(. (dup ?a) (om (id ?a) (del ?a)))" <=> "(id ?a)" ),
        rw!( "dup(a) . swap(a, a) == dup(a)" ; "(. (dup ?a) (swap ?a ?a))" <=> "(dup ?a)" ),
        rw!( "(dup(a) om dup(b)) . ((id(a) om swap(a, b)) om id(b)) == dup(a oo b)" ; "(. (om (dup ?a) (dup ?b)) (om (om (id ?a) (swap ?a ?b)) (id ?b)))" <=> "(dup (oo ?a ?b))" ),
        rw!( "dup(a) . (dup(a) om id(a)) == dup(a) . (id(a) om dup(a))" ; "(. (dup ?a) (om (dup ?a) (id ?a)))" <=> "(. (dup ?a) (om (id ?a) (dup ?a)))" ),
        rw!( "(del)(a oo b) == (del)(a) om (del)(b)" ; "(del (oo ?a ?b))" <=> "(om (del ?a) (del ?b))" ),
        rw!( "dup(munit()) == id(munit())" ; "(dup munit)" <=> "(id munit)" ),
        rw!( "(del)(munit()) == id(munit())" ; "(del munit)" <=> "(id munit)" ),
        vec![rw!( "pair(f, k) => dup(dom(type(f))) . (f om k)" ; "(pair ?f ?k)" => "(. (dup (dom (type ?f))) (om ?f ?k))" )],
        rw!( "proj1(a, b) == id(a) om (del)(b)" ; "(proj1 ?a ?b)" <=> "(om (id ?a) (del ?b))" ),
        rw!( "proj2(a, b) == (del)(a) om id(b)" ; "(proj2 ?a ?b)" <=> "(om (del ?a) (id ?b))" ),
        vec![rw!( "f . (del)(b) => (del)(dom(type(f)))" ; "(. ?f (del ?b))" => "(del (dom (type ?f)))" )],
        vec![rw!( "f . dup(b) => dup(dom(type(f))) . (f om f)" ; "(. ?f (dup ?b))" => "(. (dup (dom (type ?f))) (om ?f ?f))" )],
        vec![rw!( "dup(a) . (f om f) => f . dup(cod(type(f)))" ; "(. (dup ?a) (om ?f ?f))" => "(. ?f (dup (cod (type ?f))))" )],
    ]
    .concat();

    // Interchange law (f ⊗ p) ⋅ (g ⊗ q) = (f ⋅ g) ⊗ (p ⋅ q); only sound when both
    // composites f ⋅ g and p ⋅ q are well-typed, hence the side conditions.
    rules.push(rw!(
        "interchange";
        "(. (om ?f ?p) (om ?g ?q))" => "(om (. ?f ?g) (. ?p ?q))"
        if ConditionEqual::parse("(cod (type ?f))", "(dom (type ?g))")
        if ConditionEqual::parse("(cod (type ?p))", "(dom (type ?q))")
    ));

    // Naturality of swap, swap applied after the tensor:
    // (f ⊗ h) ⋅ σ(cod f, cod h) = σ(dom f, dom h) ⋅ (h ⊗ f).
    rules.push(rw!(
        "(f ⊗ₘ h) ⋅ σ(a, b) ";
        "(. (om ?f ?h) (swap ?a ?b))" => "(. (swap (dom (type ?f)) (dom (type ?h))) (om ?h ?f))"
        if ConditionEqual::parse("(cod (type ?f))", "?a")
        if ConditionEqual::parse("(cod (type ?h))", "?b")
    ));

    // Naturality of swap, swap applied before the tensor:
    // σ(c, d) ⋅ (h ⊗ f) = (f ⊗ h) ⋅ σ(cod f, cod h).
    // σ(c, d) : c ⊗ d -> d ⊗ c, so h must START at d and f at c. (Previously the
    // second condition tested cod(h) = d, which made the rule effectively dead.)
    rules.push(rw!(
        "σ(c, d) ⋅ (h ⊗ₘ f) ";
        "(. (swap ?c ?d) (om ?h ?f) )" => "(. (om ?f ?h) (swap (cod (type ?f)) (cod (type ?h))) )"
        if ConditionEqual::parse("(dom (type ?f))", "?c")
        if ConditionEqual::parse("(dom (type ?h))", "?d")
    ));

    // Recognise a copy followed by a parallel pair of maps out of the copied
    // object as `pair(f, k)`.
    rules.push(rw!(
        " Δ(a) ⋅ (f ⊗ₘ k)";
        "(. (dup ?a) (om ?f ?k))" => "(pair ?f ?k)"
        if ConditionEqual::parse("(dom (type ?f))", "?a")
        if ConditionEqual::parse("(dom (type ?k))", "?a")
    ));

    rules
}

/// Simplify a category-theory term by equality saturation.
///
/// `s` is an s-expression over `CatLanguage`, e.g. `(. (id a) f)`. It is added to
/// a fresh e-graph, saturated with [`cat_rules`] under egg's default `Runner`
/// limits, and the smallest equivalent term (by `AstSize`) is returned as an
/// s-expression string.
///
/// If `s` fails to parse, the parser's error message is returned instead.
/// NOTE(review): callers cannot tell an error message from a result by type alone.
///
/// A run report and the best cost/expression are also printed to stdout.
#[wasm_bindgen]
pub fn simplify(s: &str) -> String {
    let start: RecExpr<CatLanguage> = match s.parse() {
        Ok(expr) => expr,
        Err(e) => return e.to_string(),
    };

    let rules = cat_rules();
    // Limits can be tuned with `with_iter_limit` / `with_node_limit` /
    // `with_time_limit` before `run`.
    let runner = Runner::default().with_expr(&start).run(&rules);
    runner.print_report();

    let mut extractor = Extractor::new(&runner.egraph, AstSize);
    // `with_expr` registered exactly one root, so index 0 is the input term.
    let (best_cost, best_expr) = extractor.find_best(runner.roots[0]);
    println!("best cost: {}, best expr {}", best_cost, best_expr);
    best_expr.to_string()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment