/-
Using simp and norm_num for big computations

GitHub gist by @JasonGross (JasonGross/b90d470627c18d1c58a60ef625e9a683),
forked from digama0/fiat-test.lean; created November 2, 2019 04:00.
-/
import tactic.norm_num
import algebra.group_power
open prod
universes u v w ℓ
-- A computation producing an `α`, written as a chain of explicitly bound
-- intermediate values ending in a final `base` value:
-- * `base a`   : the finished result `a`;
-- * `dlet x k` : bind the integer `x` and continue with `k x`;
-- * `mlet x k` : bind a value `x` of some type `β : Type` and continue with `k x`.
-- `let_bound.eval` runs the chain by feeding each bound value to its
-- continuation. In the `example` at the end of the file, `mlet` steps are
-- rewritten with `let_bound.eval_mlet` and `dlet` steps with
-- `let_bound.eval_dlet` (which keeps them as `let_in`).
inductive let_bound (α : Type*)
| base : α → let_bound
| dlet : ℤ → (ℤ → let_bound) → let_bound
| mlet {β : Type} : β → (β → let_bound) → let_bound
open let_bound
-- Evaluate a `let_bound` computation: pass each bound value to its
-- continuation until a `base` value is reached, and return that value.
@[simp] def let_bound.eval {α} : let_bound α → α
| (base v)   := v
| (dlet z k) := let_bound.eval (k z)
| (mlet y k) := let_bound.eval (k y)
-- Monadic bind for `let_bound`: keep every `dlet`/`mlet` step of the input
-- computation and, once its final `base a` is reached, continue with `f a`.
-- The continuation stored in each step is named `k` (not `f`) so that it does
-- not shadow the argument `f`; the recursive call `let_bound.bind (k y)`
-- implicitly reuses the fixed argument `f`.
@[simp] def let_bound.bind {α β} (f : α → let_bound β) : let_bound α → let_bound β
| (base a)   := f a
| (dlet x k) := dlet x (λ y, let_bound.bind (k y))
| (mlet x k) := mlet x (λ y, let_bound.bind (k y))
-- Apply `f` to the final result of a `let_bound` computation, keeping all of
-- its `dlet`/`mlet` steps unchanged.
-- The continuation stored in each step is named `k` (not `f`) so that it does
-- not shadow the argument `f`; the recursive call `let_bound.map (k y)`
-- implicitly reuses the fixed argument `f`.
@[simp] def let_bound.map {α β} (f : α → β) : let_bound α → let_bound β
| (base a)   := base (f a)
| (dlet x k) := dlet x (λ y, let_bound.map (k y))
| (mlet x k) := mlet x (λ y, let_bound.map (k y))
-- Rewrite rule for evaluating an `mlet` step: if `e` is `mlet x f` and the
-- bound value `x` has been shown equal to some `y`, then `e.eval` equals
-- whatever `(f y).eval` equals. The `example` at the end of the file proves
-- `x = y` with `simp`/`norm_num`, so `y` is a normalised form of `x`.
theorem let_bound.eval_mlet {α β a} {x y : β} {f : β → let_bound α} {e}
(r : e = mlet x f)
(h : x = y)
(H : (f y).eval = a) :
e.eval = a := by rwa [r, let_bound.eval, show x = y, from h]
-- `let_in x f` is simply `f x`. Wrapping an application in `let_in` keeps the
-- binding of `x` as a separate, shared value; the `let_in.lift_*` lemmas move
-- such bindings outward through list operations.
def let_in {A : Type u} {B : Type v} (x : A) (f : A → B) : B :=
f x
-- Float a `let_in` out of the first argument of `list.bind`.
-- Holds by `rfl` because `let_in x f` unfolds to `f x`.
-- NOTE(review): this states a `Prop`, so `lemma`/`theorem` would be the usual
-- keyword instead of `def`; confirm `@[simp]` still registers it as a rewrite
-- rule before changing the keyword.
@[simp] def let_in.lift_bind {A B C} (x : A) (f : A → list B) (g : B → list C) :
list.bind (let_in x f) g = let_in x (λ x, list.bind (f x) g) := rfl
-- Float a `let_in` out of the list argument of `list.partition`
-- (by `rfl`, since `let_in x f` unfolds to `f x`).
@[simp] def let_in.lift_partition {A : Type u} {B : Type v}
(x : A) (f : A → list B) (g : B → Prop) [decidable_pred g] :
list.partition g (let_in x f) = let_in x (λ x, list.partition g (f x)) := rfl
-- Rewrite rule for evaluating a `dlet` step: if `e` is `dlet x f` and, for
-- every `a`, `(f a).eval` has been shown equal to `g a`, then `e.eval` is
-- `let_in x g`. The integer binding is kept as a `let_in` rather than being
-- substituted into the rest of the computation.
theorem let_bound.eval_dlet {α} {x} {f : ℤ → let_bound α} {g : ℤ → α} {e}
(r : e = dlet x f)
(H : ∀ a, (f a).eval = g a) :
e.eval = let_in x g := by rw [r, let_bound.eval, H]; refl
-- Let `simp` unfold `list.filter` by its defining equations.
attribute [simp] list.filter
-- `list.seq start len` is `[start, start + 1, ..., start + len - 1]`
-- (first argument: start value, second: length); it is `[]` when `len = 0`.
@[simp]
def list.seq : ℕ → ℕ → list ℕ
| _ 0 := []
| start (nat.succ len') := start :: list.seq (nat.succ start) len'
-- Apply `f` to the element at index `n` of `xs`, leaving every other element
-- unchanged. If `n` is out of range, the list is returned unchanged.
@[simp]
def list.update_nth' {T : Type u} : ∀ (n : ℕ) (f : T → T) (xs : list T), list T
| 0 f [] := []
| 0 f (x' :: xs') := f x' :: xs'
| (nat.succ n') f [] := []
| (nat.succ n') f (x' :: xs') := x' :: list.update_nth' n' f xs'
-- Element at index `n` of `ls`, or `default` when `n` is out of range.
@[simp]
def list.nth_default {A : Type u} (default : A) : ∀ (ls : list A) (n : ℕ), A
| [] _ := default
| (x :: xs) 0 := x
| (x :: xs) (nat.succ n') := list.nth_default xs n'
-- `list.expand_helper f n idx` is `[f idx, f (idx + 1), ..., f (idx + n - 1)]`.
-- NOTE(review): not referenced in the visible part of this file
-- (`list.expand` goes through `list.seq` instead).
@[simp]
def list.expand_helper {A : Type u} (f : ℕ → A) : ∀ (n : nat) (idx : nat), list A
| 0 idx := []
| (nat.succ n') idx := f idx :: list.expand_helper n' (nat.succ idx)
-- `[f 0, f 1, ..., f (n - 1)]`, built by mapping `f` over `list.seq 0 n`.
@[simp]
def list.expand {A} (f : ℕ → A) (n : nat) : list A := list.map f (list.seq 0 n)
-- Product of two associational lists of `(weight, coefficient)` pairs.
-- For each term `(a, b)` of `p`, in order, emit `(a * a', b * b')` for every
-- term `(a', b')` of `q`. Terms with equal weights are not combined.
@[simp] def associational.mul : ∀ (p q : list (ℕ × ℤ)), list (ℕ × ℤ)
| [] q := []
| ((a, b) :: ts) q :=
q.map (λ ⟨a', b'⟩, (a * a', b * b')) ++ associational.mul ts q
-- Split an associational list by whether each weight is a multiple of `s`.
-- Returns `(lo, hi)`, where
-- * `lo` keeps, unchanged, the terms whose weight is not a multiple of `s`;
-- * `hi` holds the terms whose weight is a multiple of `s`, with that weight
--   divided by `s`.
-- Relative order is preserved in both lists. The divisibility test and the
-- quotient are bound with `mlet`, so they appear as separate steps.
-- The binder names `lo`/`hi` match the `⟨lo, hi⟩` destructuring in
-- `associational.reduce` and the `(lo, hi)` patterns in
-- `associational.repeat_reduce`.
-- NOTE: `let_in.lift_split` refers to `associational.split._match_1`. That is
-- presumably the auxiliary function Lean generates for the `λ ⟨lo, hi⟩`
-- below, so keep this lambda's shape unchanged.
@[simp] def associational.split (s : ℕ) : list (ℕ × ℤ) → let_bound (list (ℕ × ℤ) × list (ℕ × ℤ))
| [] := base ([], [])
| ((a, b)::l) := (associational.split l).bind $ λ ⟨lo, hi⟩,
  mlet (to_bool (a % s = 0)) $ λ c,
  cond c
    (mlet (a / s) $ λ d, base (lo, (d, b) :: hi))
    (base ((a, b) :: lo, hi))
-- Float a `let_in` out of an argument of `associational.split._match_1`.
-- That is the auxiliary function Lean generated for the pattern-matching
-- lambda inside `associational.split`, so this lemma depends on that
-- auto-generated name. Holds by `rfl`.
-- NOTE(review): `f` is left untyped, so which parameter of `_match_1` receives
-- `let_in x f` is decided by elaboration — confirm it is the intended one.
@[simp] def let_in.lift_split (s) {A : Type u} (f) (x : A) :
@associational.split._match_1 s (let_in x f) =
let_in x (λ x, associational.split._match_1 s (f x)) := rfl
-- One reduction step: with `(lo, hi)` from `associational.split s p`, return
-- `lo ++ associational.mul c hi`.
-- Presumably `c` represents `s` modulo the intended modulus, so multiplying
-- the high part by `c` folds it back down (e.g. `ex2` uses `s = 2^255`,
-- `c = [(1, 19)]`) — NOTE(review): confirm against the callers' choice of `c`.
@[simp] def associational.reduce (s:ℕ) (c p) : let_bound (list (ℕ × ℤ)) :=
(associational.split s p).map $ λ ⟨lo, hi⟩,
lo ++ associational.mul c hi
-- Repeat the reduction step `lo ++ associational.mul c hi` (the same step as
-- `associational.reduce`) at most `n` times. Stops early, returning the
-- current `p`, as soon as `associational.split s p` yields no high terms.
-- The inner `let p := ...` intentionally shadows the argument `p`.
@[simp]
def associational.repeat_reduce :
∀ (n : nat) (s:ℕ) (c:list (ℕ × ℤ)) (p:list (ℕ × ℤ)), let_bound (list (ℕ × ℤ))
| 0 s c p := base p
| (nat.succ n') s c p := (associational.split s p).bind $ λ d,
match d with
| (lo, []) := base p
| (lo, hi) :=
let p := lo ++ associational.mul c hi in
associational.repeat_reduce n' s c p
end
-- Carry for a single term `(a, b)`.
-- If its weight `a` equals `w`, split the coefficient `b` by `fw` into
-- `[(w * fw, b / fw), (w, b % fw)]`; any other term is returned as
-- `[(a, b)]`. The weight test is bound with `mlet`; the coefficient, quotient
-- and remainder are each bound with `dlet`.
-- NOTE(review): for negative `b` the result depends on the rounding convention
-- of Lean's `ℤ` `/` and `%` — confirm it is the one intended.
@[simp]
def associational.carryterm (w fw : ℕ) : ℕ × ℤ → let_bound (list (ℕ × ℤ))
| (a, b) :=
mlet (to_bool (a = w)) $ λ c,
cond c
(dlet b $ λ t2,
dlet (t2 / fw) $ λ d2,
dlet (t2 % fw) $ λ m2,
base [(w * fw, d2), (w,m2)])
(base [(a, b)])
-- Apply `associational.carryterm w fw` to every term of the list and
-- concatenate the results in order (a flat-map through `let_bound`).
@[simp]
def associational.carry (w fw:ℕ) : list (ℕ × ℤ) → let_bound (list (ℕ × ℤ))
| [] := base []
| ((a, b) :: p) :=
(associational.carryterm w fw (a, b)).bind $ λ l,
(associational.carry p).map $ λ l₂, l ++ l₂
-- Positional representation: a `list ℤ` of limbs where limb `i` has weight
-- `weight i`. Every definition here is abstracted over the section parameter
-- `weight`; callers outside the section pass it explicitly (see `modops`).
section
parameters (weight : ℕ → ℕ)
-- Convert limbs `xs`, the first of which sits at index `n`, into
-- associational `(weight i, limb)` pairs.
@[simp]
def positional.to_associational_aux : list ℤ → ℕ → list (ℕ × ℤ)
| [] n := []
| (x :: xs) n := (weight n, x) :: positional.to_associational_aux xs n.succ
-- Float a `let_in` out of the list argument of `to_associational_aux` (by `rfl`).
@[simp] def let_in.lift_to_associational_aux {A : Type u} (f : A → list ℤ) (n : ℕ) (x : A) :
positional.to_associational_aux (let_in x f) n =
let_in x (λ x, positional.to_associational_aux (f x) n) := rfl
-- `[(weight 0, x₀), (weight 1, x₁), ...]` for limbs `[x₀, x₁, ...]`.
@[simp]
def positional.to_associational (xs : list ℤ) : list (ℕ × ℤ) :=
positional.to_associational_aux xs 0
-- `n` zero limbs.
@[simp]
def positional.zeros (n : ℕ) : list ℤ := list.repeat 0 n.
-- Add `x` to limb `i` of `ls` (no change if `i` is out of range).
@[simp]
def positional.add_to_nth (i : ℕ) (x : ℤ) (ls : list ℤ) : list ℤ
:= list.update_nth' i (λ y, x + y) ls.
-- Choose the limb for a term with weight `a` and coefficient `b`.
-- Limb indices are scanned downward from `i`; at the first index `i > 0`
-- where `weight i` divides `a`, return `(i, (a / weight i) * b)`.
-- Otherwise fall back to limb 0 with coefficient `a * b`, which presumes
-- `weight 0 = 1`. Each divisibility test and quotient is bound with `mlet`.
@[simp]
def positional.place (a : ℕ) (b : ℤ) : ∀ (i:ℕ), let_bound (ℕ × ℤ)
| 0 := base (0, a * b)
| i@(nat.succ i') :=
mlet (to_bool (a % (weight i) = 0)) $ λ v,
cond v
(mlet (a / weight i) (λ c, base (i, c * b)))
(positional.place i')
-- Convert associational terms into `n` limbs, starting from `n` zero limbs.
-- Each term is placed with `positional.place` (searching down from limb
-- `n - 1`) and its coefficient is added into that limb. Each placed
-- coefficient is bound with `dlet`.
@[simp]
def positional.from_associational (n : ℕ) : list (ℕ × ℤ) → let_bound (list ℤ)
| [] := base $ positional.zeros n
| ((a, b) :: ts) :=
(positional.place a b (nat.pred n)).bind $ λ ⟨i, x⟩,
dlet x $ λ x,
(positional.from_associational ts).map (positional.add_to_nth i x)
-- Append `n_out - n_in` zero limbs to `p` (truncated subtraction on ℕ).
@[simp]
def positional.extend_to_length (n_in n_out : ℕ) (p:list ℤ) : list ℤ :=
p ++ positional.zeros (n_out - n_in).
-- Keep only the first `n` limbs of `p`.
@[simp]
def positional.drop_high_to_length (n : ℕ) (p:list ℤ) : list ℤ :=
list.take n p.
section
-- `s` and `c` are passed on to `associational.repeat_reduce` by `mulmod`.
parameters (s:ℕ)
(c:list (ℕ × ℤ))
-- Multiply limb lists `a` and `b` via associational form. The product is
-- reduced with `associational.repeat_reduce n s c` and converted back to `n`
-- limbs.
@[simp]
def positional.mulmod (n:ℕ) (a b:list ℤ) : let_bound (list ℤ)
:= let a_a := positional.to_associational a in
let b_a := positional.to_associational b in
let ab_a := associational.mul a_a b_a in
let abm_a := associational.repeat_reduce n s c ab_a in
abm_a.bind $ positional.from_associational n
end
-- Sum of `a` and `b`: concatenate their associational forms and convert back
-- to `n` limbs (no carrying).
@[simp]
def positional.add (n:ℕ) (a b:list ℤ) : let_bound (list ℤ)
:= let a_a := positional.to_associational a in
let b_a := positional.to_associational b in
positional.from_associational n (a_a ++ b_a).
section
-- Carry out of limb `index` of `p`. Terms of weight `weight index` are split
-- by `fw := weight (index + 1) / weight index` (see `associational.carry`),
-- and the result is converted to `m` limbs.
-- NOTE(review): the argument `n` is not used by the body.
@[simp]
def positional.carry (n m : ℕ) (index:ℕ) (p:list ℤ) : let_bound (list ℤ) :=
let_bound.bind (positional.from_associational m) $
associational.carry (weight index)
(weight (nat.succ index) / weight index)
(positional.to_associational p)
-- Carry out of limb `index` into `n + 1` limbs, then apply one
-- `associational.reduce s c` step (presumably folding the extra high limb
-- back down) and convert the result to `n` limbs.
@[simp]
def positional.carry_reduce (n : ℕ) (s:ℕ) (c:list (ℕ × ℤ))
(index:ℕ) (p : list ℤ) : let_bound (list ℤ) :=
(@positional.carry n (nat.succ n) index p).bind $ λ a,
let x := positional.to_associational a in
(associational.reduce s c x).bind $ λ e,
positional.from_associational n e
-- Starting from the computation `p`, apply `positional.carry_reduce n s c` at
-- each index of the list. The tail is processed first, so the head index is
-- carried last.
@[simp] def positional.chained_carries_aux (n s c p) : list nat → let_bound (list ℤ)
| [] := p
| (a::l) := (positional.chained_carries_aux l).bind $ positional.carry_reduce n s c a
-- Carry (with reduction) at each index of `idxs`, first index first.
@[simp] def positional.chained_carries (n s c p) (idxs : list nat) : let_bound (list ℤ) :=
positional.chained_carries_aux n s c p (list.reverse idxs)
-- Like `chained_carries_aux`, but each step is `positional.carry n n`
-- (carry into `n` limbs, no reduction).
@[simp]
def positional.chained_carries_no_reduce_aux (n p) : list nat → let_bound (list ℤ)
| [] := p
| (a::l) := (positional.chained_carries_no_reduce_aux l).bind $ positional.carry n n a
-- Carry without reduction at each index of `idxs`, first index first.
@[simp]
def positional.chained_carries_no_reduce (n p) (idxs : list nat) : let_bound (list ℤ) :=
positional.chained_carries_no_reduce_aux n p (list.reverse idxs)
-- Encode the integer `x` as `n` limbs: convert the single term `(1, x)` to
-- limbs, then carry (with reduction) at indices `0, 1, ..., n - 1`.
@[simp]
def positional.encode (n s c) (x : ℤ) : let_bound (list ℤ) :=
positional.chained_carries n s c (positional.from_associational n [(1,x)]) (list.seq 0 n)
-- Like `positional.encode`, but the carries do no reduction.
@[simp]
def positional.encode_no_reduce (n) (x : ℤ) : let_bound (list ℤ) :=
positional.chained_carries_no_reduce n (positional.from_associational n [(1,x)]) (list.seq 0 n)
end
section
-- Only `n` is used by `positional.scmul`; `s`, `c` and `coef` are unused here.
parameters (n:ℕ)
(s:ℕ)
(c:list (ℕ × ℤ))
(coef:ℕ).
-- Multiply the limbs of `a` by the integer scalar `x` (as the associational
-- product with `[(1, x)]`) and convert back to `n` limbs. No carrying or
-- reduction is done.
@[simp]
def positional.scmul (x : ℤ) (a:list ℤ) : let_bound (list ℤ)
:= let A := positional.to_associational a in
let R := associational.mul A [(1, x)] in
positional.from_associational n R.
end
end
-- Ceiling division `⌈a / b⌉` on ℕ for `b > 0`. When `b = 0` the result is 0,
-- since ℕ division by zero is 0.
@[simp]
def divup (a b : ℕ) : ℕ := (a + b - 1) / b
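-- Alternative form of `modops.weight` below, without `divup`. It relies on
-- -(-x / d) being ⌈x / d⌉ when ℤ division rounds toward -∞:
-- := 2^(int.to_nat (-(-(limbwidth_num * i) / limbwidth_den))).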
-- Modular operations on `n`-limb representations where limb `i` has weight
-- `2 ^ divup (limbwidth_num * i) limbwidth_den`. Reduction uses `s` and `c`;
-- carries run over the indices `idxs`. Each definition takes only the section
-- parameters it uses, in declaration order. For example, the `example` below
-- calls `carry_mulmod limbwidth_num limbwidth_den s c n idxs f g`.
-- NOTE(review): `len_c` is declared but not used by any definition here.
section
open positional
parameters (limbwidth_num limbwidth_den : ℕ)
(s : ℕ)
(c : list (ℕ × ℤ))
(n : ℕ)
(len_c : ℕ)
(idxs : list ℕ)
-- Weight of limb `i`: `2 ^ ⌈limbwidth_num * i / limbwidth_den⌉`.
@[simp]
def modops.weight (i : ℕ) : ℕ
:= 2^(divup (limbwidth_num * i) limbwidth_den)
-- Multiply `f` and `g` with `positional.mulmod`, then run the carry chain
-- `idxs` (with reduction).
@[simp]
def modops.carry_mulmod (f g : list ℤ) : let_bound (list ℤ) :=
chained_carries modops.weight n s c (mulmod modops.weight s c n f g) idxs
-- Multiply `f` by the integer `x`: encode `x` as limbs, multiply with
-- `positional.mulmod`, then run the carry chain `idxs`.
@[simp]
def modops.carry_scmulmod (x : ℤ) (f : list ℤ) : let_bound (list ℤ) :=
(encode modops.weight n s c x).bind $ λ e,
chained_carries modops.weight n s c (mulmod modops.weight s c n e f) idxs
-- Run only the carry chain `idxs` (with reduction) on `f`.
@[simp]
def modops.carrymod (f : list ℤ) : let_bound (list ℤ) :=
chained_carries modops.weight n s c (base f) idxs
-- Add `f` and `g` into `n` limbs (no carries).
@[simp]
def modops.addmod (f g : list ℤ) : let_bound (list ℤ) :=
add modops.weight n f g
-- Encode the integer `f` as `n` limbs with `positional.encode`.
@[simp]
def modops.encodemod (f : ℤ) : let_bound (list ℤ) :=
encode modops.weight n s c f
end
-- "Let-floating" lemmas: applying a function to `let_in x f` equals a
-- `let_in` whose body applies that function. All hold by `rfl` because
-- `let_in x f` unfolds to `f x`. They are not tagged `@[simp]`, so they must be
-- passed to `simp` explicitly.
-- General form, for any function `F`.
def let_in.lift {A : Type u} {B : Type v} {C : Type w} (F : B → C) (x : A) (f : A → B) : F (let_in x f) = let_in x (λ y, F (f y)) := rfl
-- Through the second argument of `list.zip` (an instance of `let_in.lift`).
def let_in.lift_zip2 {A : Type u} {B : Type v} {C : Type w} (ls : list C) (x : A) (f : A → list B) : list.zip ls (let_in x f) = let_in x (λ y, list.zip ls (f y)) := let_in.lift _ _ _.
-- Through the first argument of `list.append`.
def let_in.lift_append1 {A : Type u} {B : Type v} (x : A) (f : A → list B) (ls : list B) : list.append (let_in x f) ls = let_in x (λ y, list.append (f y) ls) := rfl
-- Through the second argument of `list.append`.
def let_in.lift_append2 {A : Type u} {B : Type v} (x : A) (f : A → list B) (ls : list B) : list.append ls (let_in x f) = let_in x (λ y, list.append ls (f y)) := rfl
-- Through the list argument of `list.foldr`.
def let_in.lift_foldr {A : Type u} {B : Type v} {C : Type w} (x : A) (f : A → list B) (g : B → C → C) (init : C) : list.foldr g init (let_in x f) = let_in x (λ x, list.foldr g init (f x)) := rfl
-- Through the list argument of `list.map`.
def let_in.lift_map {A : Type u} {B : Type v} {C : Type w} (x : A) (f : A → list B) (g : B → C) : list.map g (let_in x f) = let_in x (λ x, list.map g (f x)) := rfl
-- Through `list.join`.
def let_in.lift_join {A : Type u} {B : Type v} (x : A) (f : A → list (list B)) : list.join (let_in x f) = let_in x (λ x, list.join (f x)) := rfl
-- `@[simp]` `let_in` lemmas, all by `rfl`.
-- Float a `let_in` out of the list argument of `list.update_nth'`.
@[simp] def let_in.lift_update_nth' {A : Type u} {B : Type v} (x : A) (f : A → list B) (g : B → B) (n : ℕ) : list.update_nth' n g (let_in x f) = let_in x (λ x, list.update_nth' n g (f x)) := rfl
-- Split a `let_in` of a pair into two nested `let_in`s, one per component.
@[simp] def let_in.split_pair {A : Type u} {A' : Type w} {B : Type v} (x : A) (y : A') (f : A × A' → B) : let_in (x, y) f = let_in x (λ x, let_in y (λ y, f (x, y))) := rfl
-- Inline `let_in`s that bind the numerals `0` and `1` instead of keeping them.
@[simp] def let_in.lift_nat.zero {A : Type v} (f : ℕ → A) : let_in 0 f = f 0 := rfl
@[simp] def let_in.lift_nat.one {A : Type v} (f : ℕ → A) : let_in 1 f = f 1 := rfl
-- Small test parameters, used by the `example` at the end of the file via
-- `open ex`. The trailing comments give the full-size values, which match
-- `ex2` below.
@[simp]
def ex.n : ℕ := 1 -- number of limbs (full size: 5)
@[simp]
def ex.s : ℕ := 2^16 -- full size: 2^255
@[simp]
def ex.c : list (ℕ × ℤ) := [(1, 1)] -- full size: [(1, 19)]
@[simp]
def ex.idxs : list ℕ := [0, 1] -- carry chain (full size: [0, 1, 2, 3, 4, 0, 1])
@[simp]
def ex.machine_wordsize : ℕ := 8 -- full size: 64
-- Full-size parameters: 5 limbs, `s = 2^255` and `c = [(1, 19)]` (presumably
-- for the modulus 2^255 - 19 — confirm), carry chain `[0, 1, 2, 3, 4, 0, 1]`,
-- `machine_wordsize = 64`. Not referenced in the visible part of this file.
@[simp]
def ex2.n : ℕ := 5
@[simp]
def ex2.s : ℕ := 2^255
@[simp]
def ex2.c : list (ℕ × ℤ) := [(1, 19)]
@[simp]
def ex2.idxs : list ℕ := [0, 1, 2, 3, 4, 0, 1]
@[simp]
def ex2.machine_wordsize : ℕ := 64
-- Settings for the experiment below. Printed terms are kept shallow; the
-- commented-out lines are optional notation and debugging aids.
-- `open ex` makes `n`, `s`, `c`, `idxs` and `machine_wordsize` refer to the
-- small `ex.*` parameters.
-- local notation `dlet` binders ` ≔ ` b ` in ` c:(scoped P, P) := let_in b c
set_option pp.max_depth 10
-- set_option pp.max_steps 1000000000
--set_option pp.numerals false
-- set_option pp.all true
open modops
open ex
-- set_option trace.simplify.rewrite true
-- Experiment: symbolically evaluate `carry_mulmod` on generic limbs
-- `list.expand f n` and `list.expand g n` (with the `ex` parameters) by
-- walking the `let_bound` chain inside `conv`. Each step:
-- * puts the current computation in weak head normal form;
-- * for `mlet x k`, prints `x`, proves `x = y` for a normalised `y` with
--   `simp`/`norm_num` and continues with `k y` (`let_bound.eval_mlet`);
-- * for `dlet x k`, keeps the binding as a `let_in` (`let_bound.eval_dlet`),
--   introduces the bound variable and continues under it;
-- * otherwise prints the term and fails, which ends the `tactic.repeat` loop.
-- The right-hand side is `sorry`: this is a scratch experiment, not a
-- finished proof.
-- NOTE(review): `machine_wordsize` is passed as `limbwidth_num` with
-- `limbwidth_den = 1`, so limb `i` has weight `2^(8 * i)`; confirm this is the
-- intended limb width for `s = 2^16`, `n = 1`.
example (f g : ℕ → ℤ) :
(carry_mulmod machine_wordsize 1 s c n idxs (list.expand f n) (list.expand g n)).eval = sorry :=
begin
conv {
to_lhs,
-- Loop body: normalise the head of the `let_bound` term and dispatch on its
-- constructor; the loop stops when this step fails.
(tactic.repeat $ do
t@`(let_bound.eval %%e) ← conv.lhs,
e' ← tactic.whnf e,
tactic.mk_app ``let_bound.eval [e'] >>= conv.change,
-- tactic.trace (e', e),
match e' with
| `(let_bound.mlet %%e₁ _) := do
tactic.trace e₁,
`[refine let_bound.eval_mlet rfl (by {
transitivity, simp [nat.succ_eq_add_one]; refl, norm_num; refl}) _],
-- Drop the first remaining goal. NOTE(review): presumably a side goal left
-- by `refine`; confirm which goal this discards.
tactic.get_goals >>= tactic.set_goals ∘ list.tail
| `(let_bound.dlet _ _) := do
`[apply let_bound.eval_dlet rfl, intro, trace_state]
| _ := do tactic.trace e', tactic.failed
end),
},
-- refine let_bound.eval_dlet _ _, swap 4, convert @eq.refl (let_bound (list int)) (dlet _ _), simp, intro
-- norm_num [(∘), nat.succ_eq_add_one, -add_comm, -list.partition_eq_filter_filter, list.enum,
-- list.enum_from, list.partition, -list.reverse_cons, list.reverse, list.reverse_core],
end
-- (End of gist.)