Created February 24, 2024 06:20

Save tmoux/06ed21574f402f56cad5fe0b9bcdfb81 to your computer and use it in GitHub Desktop.

This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
import Mathlib.Util.AtomM | |
import Mathlib.Data.Nat.Basic | |
open Lean Elab Meta Tactic Mathlib.Tactic | |
open Lean.Expr | |
/-!
# Monoid simplifier tactic

Normalizes both sides of a monoid equation by flattening each side into a list of
atoms and eliminating identity elements, then checks that the two lists agree.
-/
namespace Monoid

/-- Information about the monoid in which the `monoid` tactic is working. -/
structure Context where
  /-- The carrier type of the expressions being normalized. -/
  α : Expr
  /-- The universe level `u` such that `α : Type u`. -/
  univ : Level
  /-- The synthesized `Monoid α` instance. -/
  inst : Expr
  /-- The monoid identity, as the term `1 : α`. -/
  α1 : Expr
/-- Build the `Context` for normalizing the term `e`.

Infers the carrier `α` as the type of `e`, synthesizes `Monoid α`, computes the
universe level `u` with `α : Type u`, and builds the term `1 : α`.
Throws if no `Monoid α` instance is found or if `α` is not of the form `Type u`. -/
def mkContext (e : Expr) : MetaM Context := do
  let α ← inferType e
  let inst ← synthInstance (← mkAppM ``Monoid #[α])
  -- `getDecLevel` returns `u` with `α : Sort (u+1)` and throws if there is none.
  -- The previous version unified a fresh level metavariable but ignored the `Bool`
  -- result of `isDefEq`, so a failure silently left `?u` unconstrained.
  let u ← getDecLevel α
  let one ← Expr.ofNat α 1
  return ⟨α, u, inst, one⟩
/-- `c.app n args` builds the application `n.{u} α args…`.

Intended for constants `n` with a single universe parameter `u` whose first
argument is the (implicit) type `{α : Type u}`, e.g. the constructors of `exp`. -/
def Context.app (c : Context) (n : Name) : Array Expr → Expr :=
  fun args => mkAppN (mkApp (.const n [c.univ]) c.α) args
/-- `c.appInst n args` builds the application `n.{u} α inst args…`.

Intended for constants `n` with a single universe parameter `u` whose first two
arguments are `{α : Type u}` and an instance `[Monoid α]`; the instance supplied is
`c.inst`. -/
def Context.appInst (c : Context) (n : Name) : Array Expr → Expr :=
  fun args => mkAppN (mkAppN (.const n [c.univ]) #[c.α, c.inst]) args
/-- The monad the tactic runs in: read-only access to the `Context`, layered on
Mathlib's `AtomM`. -/
abbrev M : Type → Type := ReaderT Context AtomM
/-- Syntax trees for monoid expressions over the carrier `α`. -/
inductive exp (α : Type u) : Type u where
  /-- An opaque element of `α`. -/
  | atom (a : α) : exp α
  /-- The product of two subexpressions. -/
  | mult (x y : exp α) : exp α
  /-- The monoid identity. -/
  | mempty : exp α
/-- Evaluate an `exp` in the monoid `α`: `mult` becomes `*` and `mempty` becomes `1`. -/
@[simp] def exp.denote [Monoid α] (e : exp α) : α :=
  match e with
  | .mempty => 1
  | .atom a => a
  | .mult l r => exp.denote l * exp.denote r
/-- Flatten an `exp` into the list of its atoms, read left to right, dropping every
`mempty`.

The `Monoid α` instance is not used by the body, but the tactic applies this
constant through `Context.appInst`, which passes an instance argument, so the
binder must stay. -/
@[simp] def exp.normalize [Monoid α] : exp α → List α
  | .mempty => []
  | .atom a => [a]
  | .mult l r => exp.normalize l ++ exp.normalize r
/-- Multiply out a list of monoid elements, associating to the right:
`[a, b, c]` denotes `a * (b * (c * 1))`, and `[]` denotes `1`. -/
@[simp] def nf.denote [Monoid α] : List α → α
  | x :: rest => x * nf.denote rest
  | [] => 1
/-- `nf.denote` turns list append into multiplication.

Despite the name, no commutativity is involved: this is the statement that
`nf.denote` respects `++`. -/
lemma denote_mult_commute [Monoid α] (xs ys : List α) :
    nf.denote xs * nf.denote ys = nf.denote (xs ++ ys) := by
  induction xs with
  | nil => simp
  | cons x xs ih => simp [mul_assoc, ih]
/-- Flattening preserves meaning: the denotation of the normal form of `e` equals
the denotation of `e`. -/
theorem normalize_correct [Monoid α] (e : exp α) :
    nf.denote (exp.normalize e) = exp.denote e := by
  induction e with
  | atom => simp
  | mult x y hx hy => simp [hx, hy, ← denote_mult_commute]
  | mempty => simp
/-- Soundness of the reflection step: if the normal forms of `a` and `b` denote the
same element, then `a` and `b` denote the same element. -/
theorem monoid_reflect [Monoid α] (a b : exp α) :
    nf.denote (exp.normalize a) = nf.denote (exp.normalize b) →
    exp.denote a = exp.denote b := by
  intro h
  rwa [normalize_correct, normalize_correct] at h
/-- Reify a term `e : α` into an `Expr` representing a term of type `exp α`.

* `e₁ * e₂` (an `HMul.hMul` application) becomes `exp.mult`, reified recursively,
  but only when all three type arguments of `HMul.hMul` are defeq to the carrier `α`;
* otherwise, a term defeq to the monoid's `1` becomes `exp.mempty`;
* anything else becomes an opaque `exp.atom`.

No other operation is recognized, so e.g. `++` is always treated as an atom. -/
partial def exp.reify (e : Expr) : M Expr := do
  let c ← read
  if let (``HMul.hMul, #[β₁, β₂, β₃, _, e₁, e₂]) := e.getAppFnArgs then
    -- A heterogeneous product (some `βᵢ` not the carrier `α`) must stay an atom:
    -- splitting it would build an ill-typed `exp.mult` application.
    if (← isDefEq β₁ c.α) && (← isDefEq β₂ c.α) && (← isDefEq β₃ c.α) then
      let t₁ ← exp.reify e₁
      let t₂ ← exp.reify e₂
      return c.app ``exp.mult #[t₁, t₂]
  if ← isDefEq e c.α1 then
    return c.app ``exp.mempty #[]
  return c.app ``exp.atom #[e]
/-- `monoid` closes an equality goal `e₁ = e₂` between terms of a type with a
`Monoid` instance, when both sides flatten to definitionally equal lists of atoms
(products `*` are flattened and identities `1` are dropped). -/
syntax (name := monoid) "monoid" : tactic

elab_rules : tactic | `(tactic| monoid) => withMainContext do
  let target ← getMainTarget
  let some (_, e₁, e₂) := (← whnfR target).eq?
    | throwError "monoid: requires an equality goal"
  let c ← mkContext e₁
  let pf ← AtomM.run .default <| ReaderT.run (r := c) do
    -- Reify both sides into `exp` terms.
    let t₁ ← exp.reify e₁
    let t₂ ← exp.reify e₂
    -- The two normal forms must agree up to definitional unfolding.
    let n₁ := (← read).appInst ``exp.normalize #[t₁]
    let n₂ := (← read).appInst ``exp.normalize #[t₂]
    unless ← isDefEq n₁ n₂ do
      throwError "monoid: normalized forms not equal"
    -- Since `n₁` and `n₂` are defeq, `rfl : nf.denote n₁ = nf.denote n₁` also proves
    -- `nf.denote n₁ = nf.denote n₂`, which is the hypothesis of `monoid_reflect`.
    let m₁ := (← read).appInst ``nf.denote #[n₁]
    let eq ← mkAppM ``Eq.refl #[m₁]
    mkAppM ``monoid_reflect #[t₁, t₂, eq]
  -- `monoid_reflect` concludes `exp.denote t₁ = exp.denote t₂`. `closeMainGoal` does not
  -- check this against the goal, so verify it here: a `*` or `1` that is not defeq to
  -- the monoid's is then reported by the tactic rather than when the finished proof
  -- is type-checked.
  unless ← isDefEq (← inferType pf) target do
    throwError "monoid: the reflected proof does not match the goal"
  closeMainGoal pf
-- Examples

/-- `List α` as a monoid: multiplication is `List.append` and the identity is `[]`. -/
instance : Monoid (List α) :=
  { one := [],
    mul := List.append,
    one_mul := List.nil_append,
    mul_one := List.append_nil,
    mul_assoc := List.append_assoc }
def f : ℕ := 3

lemma ex1 : 1 * 2 * 3 = 1 * (2 * f) * 1 * 1 := by
  monoid

lemma ex2 : ∀ (a b c : ℕ), a * 1 * b * (a + c) = a * (b * (a + c)) := by
  intro a b c
  monoid

lemma ex3 : ∀ (a b c d : List Bool), [] * a * (b * c) * d = a * (b * c * d) := by
  intro a b c d
  monoid

lemma ex4 : ∀ (a b c d : ℕ),
    a * b * (b * d * c * a * 1 * d) = 1 * 1 * a * (1 * b) * (b * d) * c * (a * d) * 1 := by
  intro a b c d
  monoid

lemma ex5 : ∀ (a b c d : List α), a * b * c * d * [] = a * (b * c) * (1 * 1 * d) := by
  intro a b c d
  monoid

-- This fails: `monoid` only splits `*` (`HMul.hMul`), so each side below is reified
-- as a single opaque atom, and `a ++ b ++ c` is not definitionally equal to
-- `a ++ (b ++ c)` for a variable `a`.
-- lemma ex6 : ∀ (a b c : List α), a ++ b ++ c = a ++ (b ++ c) := by
--   intros
--   monoid
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.