(* Coq gist mir-ikbch/6df11b5ba869c1d17fed5681104f5eb5 by @mir-ikbch,
   created November 29, 2019 05:02 (captured from its GitHub Gist page). *)
Require Import List PeanoNat Arith Morphisms Setoid.
Import ListNotations.
Set Implicit Arguments.
(* [optimize db] solves a goal of the form [{ f | forall x.., lhs = f x.. }]:
   - [eexists] turns [f] into an evar and [intros] introduces the variables;
   - the right-hand side (which contains the evar) is hidden behind a local
     definition [Q], so the unfolding and rewriting below only touch [lhs];
   - [lhs] is unfolded with the [Hint Unfold] entries of [db] and rewritten
     with its [Hint Rewrite] entries;
   - [reflexivity] finally instantiates the evar with the resulting term. *)
Ltac optimize db :=
  eexists; intros;
  lazymatch goal with
  | |- ?lhs = ?rhs =>
      let Q := fresh "Q" in
      set (Q := fun z => z = rhs);
      enough (Q lhs) by auto
  end;
  repeat autounfold with db;
  autorewrite with db;
  reflexivity.
(* Example 1: fusing two maps.  [sqr_list] squares every element of a list
   and [sqr_add2_list] adds 2 to every element of [sqr_list xs]. *)
Definition sqr_list xs := map (fun x => x * x) xs.
Definition sqr_add2_list xs := map (fun x => x + 2) (sqr_list xs).
(* Allow [optimize my_db] to unfold both definitions. *)
Hint Unfold sqr_list sqr_add2_list : my_db.
(* [map_map] (standard List library): map g (map f l) = map (fun x => g (f x)) l. *)
Hint Rewrite map_map : my_db.
(* The first component is a function provably equal to [sqr_add2_list];
   [optimize] finds it by unfolding and rewriting with [my_db]. *)
Definition sa_optimized_sig : { f | forall xs, sqr_add2_list xs = f xs } :=
ltac:(optimize my_db).
Eval simpl in proj1_sig sa_optimized_sig.
(* result (one [map] instead of two):
[[[
= map (fun x : nat => x * x + 2)
: list nat -> list nat
]]]
*)
(* Fusion law for [fold_left]: folding [g] over [map f l] is the same as
   folding over [l] directly, applying [f] to each element on the fly. *)
Lemma fold_left_map : forall (A B C : Type)(f : A -> B) g l (init:C),
  fold_left g (map f l) init = fold_left (fun x y => g x (f y)) l init.
Proof.
  intros A B C f g l.
  induction l as [|a l' IH]; intros acc; simpl.
  - reflexivity.
  - apply IH.
Qed.
(* Example 2: sum of squares.  [sum] adds up a list with a left fold that
   starts from 0, and [sum_of_square] sums the squares of the elements. *)
Definition sum xs := fold_left Nat.add xs 0.
Definition sum_of_square xs := sum (map (fun x => x * x) xs).
(* Allow [optimize my_db] to unfold both definitions. *)
Hint Unfold sum sum_of_square : my_db.
(* [fold_left_map] moves the [map] into the folding function.  [map_map]
   was already added to [my_db] by the earlier [Hint Rewrite]; listing it
   again here is redundant. *)
Hint Rewrite fold_left_map map_map : my_db.
(* The first component is a function provably equal to [sum_of_square]. *)
Definition ss_optimized_sig : { prog | forall xs, sum_of_square xs = prog xs }
:= ltac:(optimize my_db).
Eval simpl in proj1_sig ss_optimized_sig.
(* result (each element is squared inside the fold; no intermediate list):
[[[
= fun xs : list nat => fold_left (fun x y : nat => x + y * y) xs 0
: list nat -> nat
]]]
*)
Require Import ZArith.
Open Scope Z_scope.
(* [divide_list xs] splits [xs] by position parity: the first component holds
   the elements at even indices (0, 2, 4, ...), the second those at odd
   indices (1, 3, 5, ...).
   Example: [divide_list [a; b; c; d; e] = ([a; c; e], [b; d])]. *)
Fixpoint divide_list (A : Type)(xs : list A) : (list A * list A)%type :=
  match xs with
  | [] => ([], [])
  | x :: rest =>
      (* Indices in [rest] are shifted by one relative to [xs], so the two
         halves of [rest] swap parity once [x] is put in front. *)
      match divide_list rest with
      | (rest_even, rest_odd) => (x :: rest_odd, rest_even)
      end
  end.
(* [map2 f xs ys] applies [f] to corresponding elements of [xs] and [ys];
   it stops at the end of the shorter list (extra elements are dropped). *)
Fixpoint map2 (A B C : Type)(f : A -> B -> C)(xs : list A)(ys : list B) : list C :=
  match xs with
  | x :: xs' =>
      match ys with
      | y :: ys' => f x y :: map2 f xs' ys'
      | [] => []
      end
  | [] => []
  end.
Section Fft.
(* Everything in this section is parameterised by [N]; the arithmetic
   helpers below work modulo [modulus]. *)
Variable N : Z.
(* The Fermat number 2^(2^N) + 1.  [Z.pow] yields 0 for a negative exponent,
   so for N < 0 this evaluates to 2^0 + 1 = 2; in every case modulus >= 2. *)
Definition modulus : Z := Z.pow 2 (Z.pow 2 N) + 1.
(* Modular addition, subtraction and multiplication: compute the plain [Z]
   result, then reduce it with [Z.modulo] by [modulus]. *)
Definition add (x y : Z) : Z := Z.modulo (x + y) modulus.
Definition sub (x y : Z) : Z := Z.modulo (x - y) modulus.
Definition mul (x y : Z) : Z := Z.modulo (x * y) modulus.
(* [pow_pos x p]: x^p modulo [modulus] for a positive exponent [p], obtained
   by applying [mul x] to 1 exactly p times (p modular multiplications). *)
Definition pow_pos (x : Z) (p : positive) : Z :=
  Pos.iter (mul x) 1 p.
(* [pow x y]: x^y modulo [modulus].  A zero exponent gives 1; a negative
   exponent gives 0 by convention (no modular inverse is computed). *)
Definition pow (x y : Z) : Z :=
  match y with
  | Z.pos p => pow_pos x p
  | Z0 => 1
  | Z.neg _ => 0
  end.
(* [prt n] = pow 2 (pow 2 (N + 1 - n)).  For N >= 0 and 1 <= n <= N + 1 this
   is 2^(2^(N+1-n)) modulo [modulus]; for n > N + 1 the inner exponent is
   negative, so the inner [pow] is 0 and [prt n] = 1.
   Since 2^(2^N) = -1 modulo [modulus], 2 has multiplicative order 2^(N+1),
   so in the range above [prt n] has order 2^n.  Presumably it is meant as
   the primitive 2^n-th root of unity used by the FFT (NOTE(review): confirm
   the intended range of n).
   Keep the body's shape as is: [prt] is listed in [Hint Unfold ... : my_db]
   and its unfolded form is what [pow_pow] rewrites in [fft_aux_sig]. *)
Definition prt n : Z :=
pow 2 (pow 2 (N + 1 - n)).
(* [pow_neg1 n] is (-1)^n: 1 when [n] is even and -1 when it is odd.  Only
   the parity is inspected, so negative [n] is handled the same way.
   (Not referenced by the code shown here.) *)
Definition pow_neg1 (n : Z) : Z :=
  match Z.even n with
  | true => 1
  | false => -1
  end.
(* [shiftl x y] = ((x mod modulus) * 2^y) mod modulus.  A negative shift
   amount [y] is treated as 0, so [shiftl x y = x mod modulus] for y <= 0.
   The shift is done directly on binary [Z] with [Z.shiftl].  The former
   definition converted to unary [nat] and back
   ([Z.to_nat] / [Nat.shiftl] / [Z.of_nat]); a unary number's size is
   linear in its value, so that version blows up as soon as [modulus] is
   non-trivial.  Both versions compute the same value: [modulus >= 2], so
   [x mod modulus] is non-negative and survives the nat round-trip, and
   [Z.of_nat (Z.to_nat y) = Z.max 0 y]. *)
Definition shiftl x y :=
  Z.shiftl (x mod modulus) (Z.max 0 y) mod modulus.
(* NOTE(review): both axioms below are FALSE as stated.  The section is
   therefore inconsistent (0 = 1, hence False, is derivable from them), and
   anything proved with them, notably [fft_aux_sig], is not established.
   - [shiftl_mul_pow2] with x = -1, y = 1: [pow 2 (-1)] is 0 (the negative
     exponent branch of [pow]), so the left side is [mul 0 1 = 0].  But
     [shiftl 1 (-1)] clamps the negative shift to 0 and gives
     [1 mod modulus = 1], because modulus >= 2.  The equation does hold for
     0 <= x, so it needs a hypothesis [0 <= x].
   - [pow_pow] with y = 0, z = -1: the left side is [pow 1 (-1) = 0], the
     right side is [pow x (mul 0 (-1)) = pow x 0 = 1].  It is wrong in
     general even for non-negative arguments, because [mul y z] reduces the
     exponent modulo [modulus] instead of modulo the multiplicative order of
     [x].  For example, with N = 0 (modulus = 3):
     [pow (pow 2 1) 3 = 2] but [pow 2 (mul 1 3) = pow 2 0 = 1].
   TODO: state these as lemmas with the right side conditions, prove them,
   and re-derive [fft_aux_sig]. *)
(* Intended meaning: multiplying by a power of two is a left shift. *)
Axiom shiftl_mul_pow2 : forall x y,
mul (pow 2 x) y = shiftl y x.
(* Intended meaning: (x^y)^z = x^(y*z). *)
Axiom pow_pow : forall x y z,
pow (pow x y) z = pow x (mul y z).
(* [Zseq x n] lists [n] consecutive integers starting at [x]:
   [x; 1 + x; 1 + (1 + x); ...]. *)
Fixpoint Zseq (x : Z) (n : nat) {struct n} : list Z :=
  match n with
  | S k => x :: Zseq (1 + x) k
  | O => nil
  end.
(* [Zrange x y] lists the integers from [x] to [y], both ends included, in
   increasing order.  It is empty when y < x, because [Z.to_nat] sends the
   non-positive length 1 + y - x to 0. *)
Definition Zrange (x y : Z) : list Z :=
  let len := Z.to_nat (1 + y - x) in
  Zseq x len.
(* Inclusive integer-range notation: [[0 ;..; 3]] stands for [Zrange 0 3]. *)
Notation "[ x ;..; y ]" := (Zrange x y) (at level 0).
(* Extensionality for [map2]: pointwise-equal functions give equal results
   on the same input lists. *)
Lemma map2_ext : forall A B C (f g : A -> B -> C) xs ys,
  (forall a b, f a b = g a b) ->
  map2 f xs ys = map2 g xs ys.
Proof.
  intros A B C f g xs ys Hfg.
  revert ys.
  induction xs as [|x xs' IH]; intros [|y ys']; simpl.
  - reflexivity.
  - reflexivity.
  - reflexivity.
  - rewrite Hfg, IH; reflexivity.
Qed.
(* [map2] maps pointwise-equal functions and equal lists to equal results.
   Presumably this [Proper] instance is what lets [setoid_rewrite] rewrite
   inside the function argument of [map2] (under its binders). *)
Instance map2_morphism A B C:
  Proper (pointwise_relation _ (pointwise_relation _ eq) ==>
          @eq (list A) ==> @eq (list B) ==> @eq (list C))
         (@map2 A B C).
Proof.
  intros f g Hfg xs xs' Hxs ys ys' Hys.
  subst xs' ys'.
  apply map2_ext.
  exact Hfg.
Qed.
(* Fusion law: mapping [g] over the first list before [map2 f] equals a
   single [map2] whose function applies [g] to its first argument. *)
Lemma map2_map : forall A A' B C (f : A -> B -> C) (g : A' -> A) xs ys,
  map2 f (map g xs) ys = map2 (fun x y => f (g x) y) xs ys.
Proof.
  intros A A' B C f g xs.
  induction xs as [|x xs' IH]; intros [|y ys']; simpl.
  - reflexivity.
  - reflexivity.
  - reflexivity.
  - rewrite IH; reflexivity.
Qed.
(* [fft_aux xs n] multiplies (with [mul], i.e. modulo [modulus]) the i-th
   element of [xs] by [pow (prt (n + 1)) i].  These are presumably the
   twiddle factors of one FFT butterfly stage, with [prt (n + 1)] meant as a
   primitive 2^(n+1)-th root of unity.
   NOTE(review): the exponent range [0 ;..; 2 ^ n] is inclusive, so it has
   2^n + 1 entries.  [map2] stops at the shorter list, so only the first
   [length xs] factors are used.  If [xs] is meant to have 2^n elements,
   [0 ;..; 2 ^ n - 1] would be the exact range.
   Keep the body's shape as is: [fft_aux] is unfolded via
   [Hint Unfold ... : my_db] and the result is rewritten in [fft_aux_sig]. *)
Definition fft_aux xs n :=
let omegas :=
map (fun k => pow (prt (n + 1)) k) [0 ;..; 2 ^ n]
in
map2 mul omegas xs.
(* [setoid_autorewrite' rew] performs one [setoid_rewrite] step with one of
   the lemmas in [rew].  [rew] is either a single lemma or a tuple
   [(l1, l2, ..., lk)], which Coq parses left-nested.  The last lemma [lk]
   is tried first, then the others from right to left.  The tactic fails
   only when none of them rewrites anything.  A single lemma (no tuple) is
   handled by the last clause; previously it matched no clause and failed. *)
Ltac setoid_autorewrite' rew :=
  lazymatch rew with
  | (?rew',?lem',?lem) => setoid_rewrite lem + setoid_autorewrite' (rew',lem')
  | (?lem',?lem) => setoid_rewrite lem || setoid_rewrite lem'
  | ?lem => setoid_rewrite lem
  end.
(* [setoid_autorewrite rew] repeats [setoid_autorewrite' rew] until no lemma
   of [rew] applies.  It may not terminate if the lemmas can rewrite back and
   forth. *)
Ltac setoid_autorewrite rew := repeat setoid_autorewrite' rew.
(* [optimize' db rew] solves a goal of the form [{ f | forall x.., lhs = f x.. }]:
   - [eexists] turns [f] into an evar and [intros] introduces the variables;
   - the right-hand side (holding the evar) is hidden behind a local
     definition [Q];
   - [lhs] is unfolded with the [Hint Unfold] entries of [db];
   - [lhs] is then rewritten with the lemmas of [rew] via [setoid_autorewrite].
     This uses [setoid_rewrite], which can also rewrite under binders when
     the needed [Proper] instances exist;
   - [reflexivity] instantiates the evar with the result. *)
Ltac optimize' db rew :=
  eexists; intros;
  lazymatch goal with
  | |- ?lhs = ?rhs =>
      let Q := fresh "Q" in
      set (Q := fun z => z = rhs);
      enough (Q lhs) by auto
  end;
  repeat autounfold with db;
  setoid_autorewrite rew;
  reflexivity.
(* Let [optimize'] unfold [fft_aux] and [prt]. *)
Hint Unfold fft_aux prt : my_db.
(* NOTE(review): this derivation rewrites with the axioms [pow_pow] and
   [shiftl_mul_pow2], which do not hold as stated.  For instance
   [mul (pow 2 (-1)) 1 = 0], while [shiftl 1 (-1) = 1], since a negative
   shift is clamped to 0.  As the result below shows, [shiftl_mul_pow2] is
   instantiated with the exponent [N + 1 - (n + 1)], which is negative
   whenever n > N.  In that case the extracted function really differs from
   [fft_aux]: the second twiddle factor is 1 in [fft_aux] but 2 in the
   extracted function.  Treat the extracted function as unverified. *)
Definition fft_aux_sig :
{ f | forall xs n, fft_aux xs n = f xs n }.
Proof.
optimize' my_db (shiftl_mul_pow2,pow_pow,map2_map).
Defined.
Eval simpl in proj1_sig fft_aux_sig.
(* result :
[[[
= fun (xs : list Z) (n : Z) =>
map2 (fun x y : Z => shiftl y (shiftl x (N + 1 - (n + 1)))) [0;..; 2 ^ n] xs
: list Z -> Z -> list Z
]]]
*)
(* [fft xs n] is a radix-2 decimation-in-time FFT of depth [n] over the
   integers modulo [modulus].  Presumably [xs] should have length 2^n; this
   is not checked.  For [n = 0] the input is returned unchanged; otherwise:
   - split [xs] into even- and odd-indexed elements ([divide_list]);
   - transform both halves recursively with depth [n'];
   - multiply the odd half by the twiddle factors (the function extracted
     from [fft_aux_sig]);
   - combine the halves with modular [add] and [sub] and concatenate.
   The recursion is structural in [n], so a plain [Fixpoint] is used.
   [Function] requires the FunInd library ([Require Import FunInd]), which is
   not among this file's explicit imports.  Unlike [Function], [Fixpoint]
   does not generate [fft_equation], [fft_ind] or [R_fft]. *)
Fixpoint fft (xs : list Z) (n : nat) {struct n} : list Z :=
  match n with
  | O => xs
  | S n' =>
      let (xs0, xs1) := divide_list xs in
      let xs0' := fft xs0 n' in
      let xs1' := proj1_sig fft_aux_sig (fft xs1 n') (Z.of_nat n') in
      let xs' := map2 add xs0' xs1' in
      let xs'' := map2 sub xs0' xs1' in
      xs' ++ xs''
  end.
End Fft.
(* End of gist (GitHub page footer text removed). *)