/-
Using simp and norm_num for big computations
(GitHub gist JasonGross/b90d470627c18d1c58a60ef625e9a683)
-/
import tactic.norm_num | |
import algebra.group_power | |
open prod | |
universes u v w ℓ | |
/-- A computation producing an `α` in which intermediate values are bound explicitly.
Each binder stores the bound value together with a continuation that receives it. -/
inductive let_bound (α : Type*)
-- a finished value
| base (a : α) : let_bound
-- bind an integer and continue
| dlet (x : ℤ) (k : ℤ → let_bound) : let_bound
-- bind a value of an arbitrary `Type` and continue
| mlet {β : Type} (x : β) (k : β → let_bound) : let_bound
open let_bound | |
/-- Run a `let_bound` computation: every bound value is passed to its continuation
and the final `base` value is returned. -/
@[simp] def let_bound.eval {α} : let_bound α → α
| (base v)   := v
| (dlet z k) := let_bound.eval (k z)
| (mlet y k) := let_bound.eval (k y)
/-- Sequence a `let_bound` computation with `f`: the binders are kept in place and
`f` is applied to the value found at the `base` leaf. -/
@[simp] def let_bound.bind {α β} (f : α → let_bound β) : let_bound α → let_bound β
| (base v)   := f v
| (dlet z k) := dlet z (λ z', let_bound.bind (k z'))
| (mlet y k) := mlet y (λ y', let_bound.bind (k y'))
/-- Apply the pure function `f` to the value at the `base` leaf, keeping every binder. -/
@[simp] def let_bound.map {α β} (f : α → β) : let_bound α → let_bound β
| (base v)   := base (f v)
| (dlet z k) := dlet z (λ z', let_bound.map (k z'))
| (mlet y k) := mlet y (λ y', let_bound.map (k y'))
/-- Evaluation step for `mlet`: if `e` is `mlet x f`, the bound value `x` equals `y`,
and the continuation at `y` evaluates to `a`, then `e` evaluates to `a`. -/
theorem let_bound.eval_mlet {α β a} {x y : β} {f : β → let_bound α} {e}
  (r : e = mlet x f)
  (h : x = y)
  (H : (f y).eval = a) :
  e.eval = a :=
begin
  -- replace `e` by `mlet x f` and identify `x` with `y`
  subst r,
  subst h,
  -- unfold one step of `eval`
  rw let_bound.eval,
  exact H,
end
/-- `let_in x f` is `f` applied to `x`. -/
def let_in {A : Type u} {B : Type v} (x : A) (f : A → B) : B :=
f x
/-- `list.bind` can be pushed under a `let_in` occurring in its list argument.
Stated as a `theorem` because its type is a proposition. -/
@[simp] theorem let_in.lift_bind {A B C} (x : A) (f : A → list B) (g : B → list C) :
  list.bind (let_in x f) g = let_in x (λ y, list.bind (f y) g) :=
rfl
/-- `list.partition` can be pushed under a `let_in` occurring in its list argument.
Stated as a `theorem` because its type is a proposition. -/
@[simp] theorem let_in.lift_partition {A : Type u} {B : Type v}
  (x : A) (f : A → list B) (g : B → Prop) [decidable_pred g] :
  list.partition g (let_in x f) = let_in x (λ y, list.partition g (f y)) :=
rfl
/-- Evaluation step for `dlet`: if `e` is `dlet x f` and the continuation evaluates
pointwise to `g`, then `e` evaluates to `let_in x g`. -/
theorem let_bound.eval_dlet {α} {x} {f : ℤ → let_bound α} {g : ℤ → α} {e}
  (r : e = dlet x f)
  (H : ∀ a, (f a).eval = g a) :
  e.eval = let_in x g :=
begin
  -- replace `e` by `dlet x f`
  subst r,
  -- unfold one step of `eval`, rewrite with `H`, then close `g x = let_in x g` by `refl`
  rw [let_bound.eval, H]; refl,
end
attribute [simp] list.filter | |
/-- `list.seq start len` is the list of `len` consecutive naturals beginning at `start`.
Argument order: start, then length. -/
@[simp] def list.seq : ℕ → ℕ → list ℕ
| _     0              := []
| first (nat.succ cnt) := first :: list.seq first.succ cnt
/-- `list.update_nth' n f xs` applies `f` to the element of `xs` at index `n`.
The list is returned unchanged in content when `n` is past its end. -/
@[simp] def list.update_nth' {T : Type u} : ∀ (n : ℕ) (f : T → T) (xs : list T), list T
| 0            g []          := []
| 0            g (hd :: tl)  := g hd :: tl
| (nat.succ k) g []          := []
| (nat.succ k) g (hd :: tl)  := hd :: list.update_nth' k g tl
/-- `list.nth_default default ls n` is the element of `ls` at index `n`,
or `default` when the list is exhausted before reaching that index. -/
@[simp] def list.nth_default {A : Type u} (default : A) : ∀ (ls : list A) (n : ℕ), A
| []         _            := default
| (hd :: tl) 0            := hd
| (hd :: tl) (nat.succ k) := list.nth_default tl k
/-- `list.expand_helper f n idx` lists the `n` values `f idx, f (idx + 1), …`. -/
@[simp] def list.expand_helper {A : Type u} (f : ℕ → A) : ∀ (n : nat) (idx : nat), list A
| 0            _ := []
| (nat.succ k) i := f i :: list.expand_helper k i.succ
/-- `list.expand f n` maps `f` over `list.seq 0 n`. -/
@[simp] def list.expand {A} (f : ℕ → A) (n : nat) : list A :=
(list.seq 0 n).map f
/-- Product of two lists of pairs: every pair of `p` is multiplied componentwise with
every pair of `q`. The result is grouped by the pairs of `p`, in order. -/
@[simp] def associational.mul : ∀ (p q : list (ℕ × ℤ)), list (ℕ × ℤ)
| []               q := []
| ((w, x) :: rest) q :=
  q.map (λ ⟨w', x'⟩, (w * w', x * x')) ++ associational.mul rest q
/-- Split a list of pairs according to whether the first component is divisible by `s`.
The first list of the result holds the pairs whose first component is not divisible,
unchanged. The second holds the divisible ones, with the first component divided by `s`.
The divisibility test and the quotient are bound with `mlet`. -/
@[simp] def associational.split (s : ℕ) :
  list (ℕ × ℤ) → let_bound (list (ℕ × ℤ) × list (ℕ × ℤ))
| []               := base ([], [])
| ((a, b) :: rest) :=
  (associational.split rest).bind $ λ ⟨lo, hi⟩,
    mlet (to_bool (a % s = 0)) $ λ divisible,
      cond divisible
        (mlet (a / s) $ λ q, base (lo, (q, b) :: hi))
        (base ((a, b) :: lo, hi))
/-- The auxiliary match function generated for `associational.split` can be pushed
under a `let_in`. Stated as a `theorem` because its type is a proposition. -/
@[simp] theorem let_in.lift_split (s) {A : Type u} (f) (x : A) :
  @associational.split._match_1 s (let_in x f) =
    let_in x (λ y, associational.split._match_1 s (f y)) :=
rfl
/-- Split `p` at `s`, then append `c` times the second part to the first part. -/
@[simp] def associational.reduce (s : ℕ) (c p) : let_bound (list (ℕ × ℤ)) :=
(associational.split s p).map $ λ ⟨low, high⟩,
  low ++ associational.mul c high
/-- Perform up to `n` reduction steps on `p`. Each step splits `p` at `s`. If the second
part is empty, `p` is returned as is. Otherwise the step continues with the first part
followed by `c` times the second part. With `n = 0` the input is returned without splitting. -/
@[simp] def associational.repeat_reduce :
  ∀ (n : nat) (s : ℕ) (c : list (ℕ × ℤ)) (p : list (ℕ × ℤ)), let_bound (list (ℕ × ℤ))
| 0            s c p := base p
| (nat.succ k) s c p :=
  (associational.split s p).bind $ λ parts,
    match parts with
    | (_,   [])   := base p
    | (low, high) :=
      let reduced := low ++ associational.mul c high in
      associational.repeat_reduce k s c reduced
    end
/-- Carry a single term. When the first component equals `w`, the second component is
bound with `dlet`, then its quotient and remainder by `fw`, and the result is
`[(w * fw, quotient), (w, remainder)]`. Any other term is returned as a singleton list. -/
@[simp] def associational.carryterm (w fw : ℕ) : ℕ × ℤ → let_bound (list (ℕ × ℤ))
| (a, b) :=
  mlet (to_bool (a = w)) $ λ at_weight,
    cond at_weight
      (dlet b $ λ val,
       dlet (val / fw) $ λ quot,
       dlet (val % fw) $ λ rem,
         base [(w * fw, quot), (w, rem)])
      (base [(a, b)])
/-- Apply `associational.carryterm w fw` to every term and concatenate the results in order. -/
@[simp] def associational.carry (w fw : ℕ) : list (ℕ × ℤ) → let_bound (list (ℕ × ℤ))
| []               := base []
| ((a, b) :: rest) :=
  (associational.carryterm w fw (a, b)).bind $ λ carried,
    (associational.carry rest).map $ λ tail, carried ++ tail
section | |
parameters (weight : ℕ → ℕ) | |
/-- Pair each entry of the list with `weight` of its position, counting positions
upward from the given starting index. -/
@[simp] def positional.to_associational_aux : list ℤ → ℕ → list (ℕ × ℤ)
| []            _   := []
| (limb :: lms) idx := (weight idx, limb) :: positional.to_associational_aux lms (nat.succ idx)
/-- `positional.to_associational_aux` can be pushed under a `let_in` occurring in its
list argument. Stated as a `theorem` because its type is a proposition. -/
@[simp] theorem let_in.lift_to_associational_aux {A : Type u} (f : A → list ℤ) (n : ℕ) (x : A) :
  positional.to_associational_aux (let_in x f) n =
    let_in x (λ y, positional.to_associational_aux (f y) n) :=
rfl
/-- Pair each entry of `xs` with `weight` of its position, starting at position `0`. -/
@[simp] def positional.to_associational (xs : list ℤ) : list (ℕ × ℤ) :=
positional.to_associational_aux xs 0
/-- The list consisting of `n` zeros. -/
@[simp] def positional.zeros (n : ℕ) : list ℤ :=
list.repeat (0 : ℤ) n
/-- Add `x` to the entry of `ls` at index `i`, using `list.update_nth'`. -/
@[simp] def positional.add_to_nth (i : ℕ) (x : ℤ) (ls : list ℤ) : list ℤ :=
list.update_nth' i (λ cur, x + cur) ls
/-- Find a position for the term `(a, b)`, searching downward from index `i`.
At a successor index `i`, if `weight i` divides `a` the result is `(i, (a / weight i) * b)`.
Otherwise the search continues at the predecessor. Index `0` always yields `(0, a * b)`.
The divisibility test and the quotient are bound with `mlet`. -/
@[simp] def positional.place (a : ℕ) (b : ℤ) : ∀ (i : ℕ), let_bound (ℕ × ℤ)
| 0               := base (0, a * b)
| i@(nat.succ i') :=
  mlet (to_bool (a % weight i = 0)) $ λ divides,
    cond divides
      (mlet (a / weight i) (λ q, base (i, q * b)))
      (positional.place i')
/-- Convert a list of terms to a list of `n` integers. The result starts from `n` zeros.
Each term is placed with `positional.place`, searching from index `n - 1`, and its value
is bound with `dlet` and added at the chosen index. -/
@[simp] def positional.from_associational (n : ℕ) : list (ℕ × ℤ) → let_bound (list ℤ)
| []                := base (positional.zeros n)
| ((a, b) :: terms) :=
  (positional.place a b n.pred).bind $ λ ⟨idx, val⟩,
    dlet val $ λ v,
      (positional.from_associational terms).map (positional.add_to_nth idx v)
/-- Append `n_out - n_in` zeros to `p` (natural-number subtraction). -/
@[simp] def positional.extend_to_length (n_in n_out : ℕ) (p : list ℤ) : list ℤ :=
p ++ positional.zeros (n_out - n_in)

/-- Keep only the first `n` entries of `p`. -/
@[simp] def positional.drop_high_to_length (n : ℕ) (p : list ℤ) : list ℤ :=
p.take n
section
parameters (s : ℕ) (c : list (ℕ × ℤ))

/-- Multiply `a` and `b`: convert both to term lists, multiply them, run
`associational.repeat_reduce n s c` on the product, and convert back to `n` entries. -/
@[simp] def positional.mulmod (n : ℕ) (a b : list ℤ) : let_bound (list ℤ) :=
let lhs     := positional.to_associational a in
let rhs     := positional.to_associational b in
let prod    := associational.mul lhs rhs in
let reduced := associational.repeat_reduce n s c prod in
reduced.bind $ positional.from_associational n
end
/-- Add `a` and `b`: convert both to term lists, concatenate them, and convert the
result back to `n` entries. -/
@[simp] def positional.add (n : ℕ) (a b : list ℤ) : let_bound (list ℤ) :=
let lhs := positional.to_associational a in
let rhs := positional.to_associational b in
positional.from_associational n (lhs ++ rhs)
section | |
/-- Carry at position `index`: convert `p` to terms and run `associational.carry` with
weight `weight index` and factor `weight (index + 1) / weight index`. The result is then
converted back to `m` entries. The argument `n` is not used in the body. -/
@[simp] def positional.carry (n m : ℕ) (index : ℕ) (p : list ℤ) : let_bound (list ℤ) :=
(associational.carry
    (weight index)
    (weight (nat.succ index) / weight index)
    (positional.to_associational p)).bind
  (positional.from_associational m)
/-- Carry at `index` into `n + 1` entries, then reduce with `associational.reduce s c`
and convert the result back to `n` entries. -/
@[simp] def positional.carry_reduce (n : ℕ) (s : ℕ) (c : list (ℕ × ℤ))
  (index : ℕ) (p : list ℤ) : let_bound (list ℤ) :=
(positional.carry n (nat.succ n) index p).bind $ λ carried,
  let terms := positional.to_associational carried in
  (associational.reduce s c terms).bind $ λ reduced,
    positional.from_associational n reduced
/-- Starting from `p`, apply `positional.carry_reduce n s c` once per index.
The carry for the head of the list is applied last. -/
@[simp] def positional.chained_carries_aux (n s c p) : list nat → let_bound (list ℤ)
| []            := p
| (idx :: rest) :=
  (positional.chained_carries_aux rest).bind (positional.carry_reduce n s c idx)

/-- Apply `positional.carry_reduce n s c` at each index of `idxs`, in list order,
by running the auxiliary function on the reversed list. -/
@[simp] def positional.chained_carries (n s c p) (idxs : list nat) : let_bound (list ℤ) :=
positional.chained_carries_aux n s c p idxs.reverse
/-- Starting from `p`, apply `positional.carry n n` once per index.
The carry for the head of the list is applied last. -/
@[simp] def positional.chained_carries_no_reduce_aux (n p) : list nat → let_bound (list ℤ)
| []            := p
| (idx :: rest) :=
  (positional.chained_carries_no_reduce_aux rest).bind (positional.carry n n idx)

/-- Apply `positional.carry n n` at each index of `idxs`, in list order,
by running the auxiliary function on the reversed list. -/
@[simp] def positional.chained_carries_no_reduce (n p) (idxs : list nat) : let_bound (list ℤ) :=
positional.chained_carries_no_reduce_aux n p idxs.reverse
/-- Encode the integer `x`: place the single term `(1, x)` into `n` entries, then run
`positional.chained_carries` over the indices `list.seq 0 n`. -/
@[simp] def positional.encode (n s c) (x : ℤ) : let_bound (list ℤ) :=
positional.chained_carries n s c
  (positional.from_associational n [(1, x)])
  (list.seq 0 n)

/-- Same as `positional.encode`, but using `positional.chained_carries_no_reduce`. -/
@[simp] def positional.encode_no_reduce (n) (x : ℤ) : let_bound (list ℤ) :=
positional.chained_carries_no_reduce n
  (positional.from_associational n [(1, x)])
  (list.seq 0 n)
end | |
section
parameters (n : ℕ)
           (s : ℕ)
           (c : list (ℕ × ℤ))
           (coef : ℕ)

/-- Multiply every entry of `a` by the scalar `x`: convert `a` to terms, multiply by the
single term `(1, x)`, and convert back to `n` entries. -/
@[simp] def positional.scmul (x : ℤ) (a : list ℤ) : let_bound (list ℤ) :=
let terms  := positional.to_associational a in
let scaled := associational.mul terms [(1, x)] in
positional.from_associational n scaled
end
end | |
/-- `divup a b` is `(a + b - 1) / b` on natural numbers. For positive `b` this is the
quotient `a / b` rounded up. -/
@[simp] def divup (a b : ℕ) : ℕ :=
(a + b - 1) / b
-- := 2^(int.to_nat (-(-(limbwidth_num * i) / limbwidth_den))). | |
section | |
open positional | |
parameters (limbwidth_num limbwidth_den : ℕ) | |
(s : ℕ) | |
(c : list (ℕ × ℤ)) | |
(n : ℕ) | |
(len_c : ℕ) | |
(idxs : list ℕ) | |
/-- Weight of limb `i`: two to the power `divup (limbwidth_num * i) limbwidth_den`. -/
@[simp] def modops.weight (i : ℕ) : ℕ :=
2 ^ divup (limbwidth_num * i) limbwidth_den
/-- `positional.mulmod` of `f` and `g` under `modops.weight`, followed by
`positional.chained_carries` over `idxs`. -/
@[simp] def modops.carry_mulmod (f g : list ℤ) : let_bound (list ℤ) :=
positional.chained_carries modops.weight n s c
  (positional.mulmod modops.weight s c n f g)
  idxs
/-- Encode the scalar `x` with `positional.encode`, multiply the encoding with `f` using
`positional.mulmod`, then run `positional.chained_carries` over `idxs`. -/
@[simp] def modops.carry_scmulmod (x : ℤ) (f : list ℤ) : let_bound (list ℤ) :=
(positional.encode modops.weight n s c x).bind $ λ encoded,
  positional.chained_carries modops.weight n s c
    (positional.mulmod modops.weight s c n encoded f)
    idxs
/-- Run `positional.chained_carries` over `idxs`, starting from `f`. -/
@[simp] def modops.carrymod (f : list ℤ) : let_bound (list ℤ) :=
positional.chained_carries modops.weight n s c (base f) idxs

/-- `positional.add` of `f` and `g` under `modops.weight`, with `n` entries. -/
@[simp] def modops.addmod (f g : list ℤ) : let_bound (list ℤ) :=
positional.add modops.weight n f g

/-- `positional.encode` of the integer `f` under `modops.weight`. -/
@[simp] def modops.encodemod (f : ℤ) : let_bound (list ℤ) :=
positional.encode modops.weight n s c f
end | |
/-- Any function `F` can be pushed under a `let_in` occurring in its argument.
Stated as a `theorem` because its type is a proposition. -/
theorem let_in.lift {A : Type u} {B : Type v} {C : Type w} (F : B → C) (x : A) (f : A → B) :
  F (let_in x f) = let_in x (λ y, F (f y)) :=
rfl

/-- `list.zip ls` can be pushed under a `let_in` occurring in its second argument. -/
theorem let_in.lift_zip2 {A : Type u} {B : Type v} {C : Type w}
  (ls : list C) (x : A) (f : A → list B) :
  list.zip ls (let_in x f) = let_in x (λ y, list.zip ls (f y)) :=
let_in.lift _ _ _
/-- `list.append` can be pushed under a `let_in` occurring in its first argument.
Stated as a `theorem` because its type is a proposition. -/
theorem let_in.lift_append1 {A : Type u} {B : Type v} (x : A) (f : A → list B) (ls : list B) :
  list.append (let_in x f) ls = let_in x (λ y, list.append (f y) ls) :=
rfl

/-- `list.append` can be pushed under a `let_in` occurring in its second argument.
Stated as a `theorem` because its type is a proposition. -/
theorem let_in.lift_append2 {A : Type u} {B : Type v} (x : A) (f : A → list B) (ls : list B) :
  list.append ls (let_in x f) = let_in x (λ y, list.append ls (f y)) :=
rfl
/-- `list.foldr` can be pushed under a `let_in` occurring in its list argument.
Stated as a `theorem` because its type is a proposition. -/
theorem let_in.lift_foldr {A : Type u} {B : Type v} {C : Type w}
  (x : A) (f : A → list B) (g : B → C → C) (init : C) :
  list.foldr g init (let_in x f) = let_in x (λ y, list.foldr g init (f y)) :=
rfl

/-- `list.map` can be pushed under a `let_in` occurring in its list argument.
Stated as a `theorem` because its type is a proposition. -/
theorem let_in.lift_map {A : Type u} {B : Type v} {C : Type w}
  (x : A) (f : A → list B) (g : B → C) :
  list.map g (let_in x f) = let_in x (λ y, list.map g (f y)) :=
rfl

/-- `list.join` can be pushed under a `let_in` occurring in its argument.
Stated as a `theorem` because its type is a proposition. -/
theorem let_in.lift_join {A : Type u} {B : Type v} (x : A) (f : A → list (list B)) :
  list.join (let_in x f) = let_in x (λ y, list.join (f y)) :=
rfl
/-- `list.update_nth'` can be pushed under a `let_in` occurring in its list argument.
Stated as a `theorem` because its type is a proposition. -/
@[simp] theorem let_in.lift_update_nth' {A : Type u} {B : Type v}
  (x : A) (f : A → list B) (g : B → B) (n : ℕ) :
  list.update_nth' n g (let_in x f) = let_in x (λ y, list.update_nth' n g (f y)) :=
rfl

/-- A `let_in` binding a pair is the same as two nested `let_in`s binding its components.
Stated as a `theorem` because its type is a proposition. -/
@[simp] theorem let_in.split_pair {A : Type u} {A' : Type w} {B : Type v}
  (x : A) (y : A') (f : A × A' → B) :
  let_in (x, y) f = let_in x (λ x', let_in y (λ y', f (x', y'))) :=
rfl
/-- A `let_in` binding the literal `0` is the continuation applied to `0`.
Stated as a `theorem` because its type is a proposition. -/
@[simp] theorem let_in.lift_nat.zero {A : Type v} (f : ℕ → A) : let_in 0 f = f 0 :=
rfl

/-- A `let_in` binding the literal `1` is the continuation applied to `1`.
Stated as a `theorem` because its type is a proposition. -/
@[simp] theorem let_in.lift_nat.one {A : Type v} (f : ℕ → A) : let_in 1 f = f 1 :=
rfl
/- Constants for the small example. Each trailing comment shows the value that the
   corresponding `ex2` constant uses. -/

/-- Value passed as `n` in the example. -/
@[simp] def ex.n : ℕ := 1 -- 5

/-- Value passed as `s` in the example. -/
@[simp] def ex.s : ℕ := 2 ^ 16 -- 2^255

/-- Term list passed as `c` in the example. -/
@[simp] def ex.c : list (ℕ × ℤ) := [(1, 1)] -- [(1, 19)]

/-- Indices passed as `idxs` in the example. -/
@[simp] def ex.idxs : list ℕ := [0, 1] -- [0, 1, 2, 3, 4, 0, 1]

/-- Value passed as `limbwidth_num` in the example. -/
@[simp] def ex.machine_wordsize : ℕ := 8 -- 64
/- Constants for the second parameter set. -/

/-- Value of `n` in the second parameter set. -/
@[simp] def ex2.n : ℕ := 5

/-- Value of `s` in the second parameter set. -/
@[simp] def ex2.s : ℕ := 2 ^ 255

/-- Term list `c` in the second parameter set. -/
@[simp] def ex2.c : list (ℕ × ℤ) := [(1, 19)]

/-- Indices `idxs` in the second parameter set. -/
@[simp] def ex2.idxs : list ℕ := [0, 1, 2, 3, 4, 0, 1]

/-- Machine word size in the second parameter set. -/
@[simp] def ex2.machine_wordsize : ℕ := 64
-- local notation `dlet` binders ` ≔ ` b ` in ` c:(scoped P, P) := let_in b c | |
set_option pp.max_depth 10 | |
-- set_option pp.max_steps 1000000000 | |
--set_option pp.numerals false | |
-- set_option pp.all true | |
open modops | |
open ex | |
-- set_option trace.simplify.rewrite true | |
/- Experiment: evaluate `carry_mulmod` on the `ex` constants one binder at a time.
   The right-hand side of the statement is `sorry`, so this is not a finished proof. -/
example (f g : ℕ → ℤ) :
  (carry_mulmod machine_wordsize 1 s c n idxs (list.expand f n) (list.expand g n)).eval = sorry :=
begin
  conv {
    to_lhs,
    -- Repeat until a step fails: the focused term is expected to be `let_bound.eval e`.
    (tactic.repeat $ do
      t@`(let_bound.eval %%e) ← conv.lhs,
      -- weak-head normalise `e` and put the result back under `let_bound.eval`
      e' ← tactic.whnf e,
      tactic.mk_app ``let_bound.eval [e'] >>= conv.change,
      -- tactic.trace (e', e),
      match e' with
      -- `mlet`: trace the bound value, apply `eval_mlet`, then drop the first goal
      | `(let_bound.mlet %%e₁ _) := do
        tactic.trace e₁,
        `[refine let_bound.eval_mlet rfl (by {
          transitivity, simp [nat.succ_eq_add_one]; refl, norm_num; refl}) _],
        tactic.get_goals >>= tactic.set_goals ∘ list.tail
      -- `dlet`: apply `eval_dlet`, introduce the bound variable, print the state
      | `(let_bound.dlet _ _) := do
        `[apply let_bound.eval_dlet rfl, intro, trace_state]
      -- anything else: trace the term and fail, which ends the `repeat`
      | _ := do tactic.trace e', tactic.failed
      end),
  },
  -- refine let_bound.eval_dlet _ _, swap 4, convert @eq.refl (let_bound (list int)) (dlet _ _), simp, intro
  -- norm_num [(∘), nat.succ_eq_add_one, -add_comm, -list.partition_eq_filter_filter, list.enum,
  -- list.enum_from, list.partition, -list.reverse_cons, list.reverse, list.reverse_core],
end
-- (GitHub page footer) Sign up for free to join this conversation on GitHub.
-- Already have an account? Sign in to comment.