-
-
Save kmill/9055e1fc32d75b818cbb036650d309ed to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- `exfalso`: replace the current goal with `False` (anything follows from a contradiction).
syntax "exfalso" : tactic
macro_rules
  | `(tactic| exfalso) => `(tactic| apply False.elim)
/-- A vector represented by its indexing function: the entry at `i : ι` is `toFun i`. -/
structure Vec (ι : Type u) (α : Type v) where
  -- entry lookup
  toFun : ι → α
-- `vec i => b` builds a `Vec` whose entry at index `i` is `b`.
-- Several binders nest: `vec i j => b` is `Vec.mk (fun i => Vec.mk (fun j => b))`,
-- i.e. a `Vec ι (Vec κ α)` (see `Vec.pop`).
macro "vec" xs:Lean.explicitBinders " => " b:term : term => Lean.expandExplicitBinders `Vec.mk xs b | |
/-- Indexing hook: `v[i]` on a `Vec` is routed here and simply applies the underlying function. -/
@[inline] def Vec.getOp (self : Vec ι α) (idx : ι) : α :=
  Vec.toFun self idx
/-- Types that can be viewed as a function-backed vector. The index type `ι` and the
entry type `α` are output parameters, so they are determined by `β`. -/
class HasVec (β : Type u) (ι : outParam (Type v)) (α : outParam (Type w)) where
  -- the function view of a value
  toVec : β → Vec ι α

export HasVec (toVec)
-- A `Vec` is already its own function view.
instance : HasVec (Vec ι α) ι α where
  toVec := fun v => v
-- Pointwise arithmetic on `Vec`: each operator is applied entry by entry.
instance [Add α] : Add (Vec ι α) where
  add v w := Vec.mk (fun i => v.toFun i + w.toFun i)

instance [Sub α] : Sub (Vec ι α) where
  sub v w := Vec.mk (fun i => v.toFun i - w.toFun i)

instance [Mul α] : Mul (Vec ι α) where
  mul v w := Vec.mk (fun i => v.toFun i * w.toFun i)

instance [Div α] : Div (Vec ι α) where
  div v w := Vec.mk (fun i => v.toFun i / w.toFun i)
namespace Vec

/-- Flatten a vector of vectors into one vector indexed by pairs `(outer, inner)`. -/
def push (v : Vec ι (Vec κ α)) : Vec (ι × κ) α :=
  Vec.mk (fun p => (v.toFun p.1).toFun p.2)

/-- Curry a pair-indexed vector into a vector of vectors (the counterpart of `push`). -/
def pop (v : Vec (ι × κ) α) : Vec ι (Vec κ α) :=
  Vec.mk (fun i => Vec.mk (fun j => v.toFun (i, j)))

/-- Pull `v` back along `f`: entry `k` of the result is entry `f k` of `v`. -/
def reindex (v : Vec ι α) (f : κ → ι) : Vec κ α :=
  Vec.mk (fun k => v.toFun (f k))

/-- `Vec ι` is the reader monad `ι → ·` wrapped in a structure:
* `pure` gives a constant vector;
* `map` and `seq` act pointwise;
* `bind` takes the diagonal, `(bind v f)[i] = (f v[i])[i]`.

The diagonal is exactly the reader monad's bind, so the monad laws do hold pointwise
(after applying `toFun`). For example, `(bind (pure x) f)[i] = (f x)[i]` and
`(bind v pure)[i] = v[i]` by unfolding. -/
instance : Monad (Vec ι) where
  pure x := Vec.mk (fun _ => x)
  map f v := Vec.mk (fun i => f (v.toFun i))
  seq f v := Vec.mk (fun i => (f.toFun i) (v.toFun i))
  bind v f := Vec.mk (fun i => (f (v.toFun i)).toFun i)

/-- A numeric literal `n` denotes the constant vector whose entries are all `n`. -/
instance [OfNat α n] : OfNat (Vec ι α) n where
  ofNat := Vec.mk (fun _ => OfNat.ofNat n)

/-- Swap the outer and inner index of a vector of vectors. -/
def transpose (v : Vec ι (Vec κ α)) : Vec κ (Vec ι α) :=
  Vec.mk (fun j => Vec.mk (fun i => (v.toFun i).toFun j))

end Vec
/-- `g` is a left inverse of `f` when `g` undoes `f`: `g (f a) = a` for every `a`. -/
def Function.leftInverse (g : β → α) (f : α → β) : Prop :=
  ∀ a, g (f a) = a

/-- `g` is a right inverse of `f` when `f` undoes `g`, i.e. `f` is a left inverse of `g`. -/
def Function.rightInverse (g : β → α) (f : α → β) : Prop :=
  Function.leftInverse f g
/-- An equivalence `α ≃ β`: maps in both directions that undo each other. -/
structure Equiv (α : Sort u) (β : Sort v) where
  -- forward direction
  toFun : α → β
  -- backward direction
  invFun : β → α
  -- `invFun (toFun a) = a` for every `a`
  leftInv : Function.leftInverse invFun toFun
  -- `toFun (invFun b) = b` for every `b`
  rightInv : Function.rightInverse invFun toFun

infix:25 " ≃ " => Equiv
/-- Reverse an equivalence by swapping the two maps, and with them the two inverse laws. -/
def Equiv.symm (f : α ≃ β) : β ≃ α :=
  ⟨f.invFun, f.toFun, f.rightInv, f.leftInv⟩
/-- An equivalence "is" a function: applying `e : α ≃ β` to `a` means `e.toFun a`.
Stated for arbitrary `Sort`s to match the universe generality of `Equiv` itself, so
equivalences between propositions coerce as well, not only those between `Type`s. -/
instance {α : Sort u} {β : Sort v} : CoeFun (α ≃ β) (fun _ => α → β) where
  coe := Equiv.toFun
/-- A type with finitely many elements, listed by an explicit equivalence with `Fin card`. -/
class Enumerable (α : Type u) where
  -- number of elements
  card : Nat
  -- `enum a` is the position of `a`; `enum.symm k` is the element at position `k`
  enum : α ≃ Fin card
section CartesianProduct | |
/-- Bound for the row-major pair code: if `i < m` and `j < n` then `i * n + j < m * n`,
so the code of `(i, j)` fits in `Fin (m * n)`. -/
theorem cartEncodeProp {i j m n : Nat} (hi : i < m) (hj : j < n) : i * n + j < m * n := by | |
cases m with | |
-- `m = 0` is impossible: `hi : i < 0` has no proof.
| zero => exfalso; exact Nat.notLtZero _ hi | |
-- `m + 1`: rewrite `(m+1) * n` to `m * n + n`, then chain
-- `i * n + j ≤ m * n + j` (from `i ≤ m`) with `m * n + j < m * n + n` (from `j < n`).
| succ m => { | |
rw Nat.succMul; | |
exact Nat.ltOfLeOfLt (Nat.addLeAddRight (Nat.mulLeMulRight _ (Nat.leOfLtSucc hi)) _) (Nat.addLtAddLeft hj _) | |
} | |
/-- Row-major decoding, intended as the inverse of `(i, j) ↦ i * m + j`:
`k` becomes the pair `(k / m, k % m)`.
TODO: the bound `k / m < n` is still `sorry`. -/
def cartDecode {n m : Nat} (k : Fin (n * m)) : Fin n × Fin m := | |
let ⟨k, h⟩ := k | |
( | |
⟨k / m, sorry⟩, | |
-- `k % m < m` needs `m > 0`; the case `m = 0` contradicts `h : k < n * 0`.
⟨k % m, Nat.modLt _ (by { cases m; exfalso; rw Nat.mulZero at h; exact Nat.notLtZero _ h; apply Nat.succPos})⟩ | |
) | |
/-- Pairs are enumerated in row-major order: `(a, b)` gets position
`enum a * card β + enum b`, and positions are decoded back with `cartDecode`.
TODO: both inverse laws are still `sorry`. -/
instance [Enumerable α] [Enumerable β] : Enumerable (α × β) where | |
card := Enumerable.card α * Enumerable.card β | |
enum := { | |
toFun := λ (a, b) => | |
let ⟨i, hi⟩ := Enumerable.enum a | |
let ⟨j, hj⟩ := Enumerable.enum b | |
⟨i * Enumerable.card β + j, cartEncodeProp hi hj⟩ | |
invFun := λ n => | |
let (i, j) := cartDecode n | |
(Enumerable.enum.symm i, Enumerable.enum.symm j) | |
leftInv := sorry | |
rightInv := sorry | |
} | |
end CartesianProduct | |
/-- `Fin n` enumerates itself: `card` is `n` and `enum` is the identity in both directions. -/
instance : Enumerable (Fin n) where
  card := n
  enum := ⟨id, id, fun _ => rfl, fun _ => rfl⟩
/-- `Bool` has two elements: `false` is position `0` and `true` is position `1`. -/
instance : Enumerable Bool where | |
card := 2 | |
enum := { | |
toFun := fun | |
| false => 0 | |
| true => 1 | |
invFun := fun | |
| ⟨0, _⟩ => false | |
| ⟨1, _⟩ => true | |
-- Positions `n + 2` are unreachable: peeling two successors off `h : n + 2 < 2`
-- leaves `n + 1 ≤ 0`, which `Nat.notSuccLeZero` refutes.
| ⟨n+2, h⟩ => False.elim (Nat.notSuccLeZero _ (Nat.leOfSuccLeSucc (Nat.leOfSuccLeSucc h))) | |
-- Both inverse laws hold by computation on each constructor / position.
leftInv := by | |
intro | |
| true => rfl | |
| false => rfl | |
rightInv := by | |
intro | |
| ⟨0, _⟩ => rfl | |
| ⟨1, _⟩ => rfl | |
-- Same impossible case as in `invFun`.
| ⟨n+2, h⟩ => exact False.elim (Nat.notSuccLeZero _ (Nat.leOfSuccLeSucc (Nat.leOfSuccLeSucc h))) | |
} | |
/-- `Empty` has no inhabitants, so it is enumerated by `Fin 0`; every field is vacuous. -/
instance : Enumerable Empty where
  card := 0
  enum := {
    toFun    := fun e => nomatch e
    -- `Fin 0` is uninhabited too: `hk : k < 0` is absurd.
    invFun   := fun ⟨k, hk⟩ => False.elim (Nat.notSuccLeZero k hk)
    leftInv  := fun e => nomatch e
    rightInv := fun e => nomatch e
  }
/-- Helper for `Enumerable.listOf`: the elements at positions `lo, lo + 1, …`, at most
`left` of them, stopping early once a position reaches `Enumerable.card α`. -/
def Enumerable.listOf.aux (α : Type u) [Enumerable α] : Nat → Nat → List α
  | _, 0 => []
  | pos, (remaining + 1) =>
    if hPos : pos < Enumerable.card α then
      Enumerable.enum.symm ⟨pos, hPos⟩ :: Enumerable.listOf.aux α (pos + 1) remaining
    else
      -- Out of range: shouldn't happen from `listOf`, but keeps the definition simple.
      []
/-- Every element of the `Enumerable` type, in `enum` order: position `0`, then `1`, and so on. -/
def Enumerable.listOf (α : Type u) [Enumerable α] : List α :=
  let total := Enumerable.card α
  Enumerable.listOf.aux α 0 total
/-- Add up all entries of `v`, visiting indices in `Enumerable.listOf ι` order and
accumulating left-to-right from `0`. -/
def Vec.sum [Enumerable ι] [Add α] [OfNat α Nat.zero] (v : Vec ι α) : α :=
  List.foldl (fun (acc : α) (i : ι) => acc + v[i]) (0 : α) (Enumerable.listOf ι)
/-- A vector over a finite index type, stored as an `Array` with exactly one slot per index. -/
structure DenseVec (ι : Type u) [Enumerable ι] (α : Type v) where
  -- backing storage
  array : Array α
  -- invariant: one slot per element of `ι`
  hasSize : array.size = Enumerable.card ι
namespace DenseVec | |
-- Implicit parameters shared by every definition in this namespace.
variables {ι : Type u} [Enumerable ι] {α : Type v} | |
/-- A dense vector with every entry equal to `a` (`Enumerable.card ι` copies). -/
def fill (a : α) : DenseVec ι α where | |
array := Array.mkArray (Enumerable.card ι) a | |
hasSize := Array.sizeMkArrayEq .. | |
/-- A dense vector filled with the type's `Inhabited.default` value. -/
def empty [Inhabited α] : DenseVec ι α := | |
fill Inhabited.default | |
/-- Convert an index `i : ι` into a bounds-checked position in the backing array:
its `Enumerable.enum` position, made valid for `v.array` by the `hasSize` invariant. -/
def translateIdx (v : DenseVec ι α) (i : ι) : Fin v.array.size := | |
let ⟨n, h⟩ := Enumerable.enum i | |
⟨n, by rw v.hasSize; exact h⟩ | |
/-- Get the value associated to a particular index. -/ | |
def get (v : DenseVec ι α) (i : ι) : α := | |
v.array.get (v.translateIdx i) | |
/-- support `v[i]` notation. -/ | |
@[inline] def getOp (self : DenseVec ι α) (idx : ι) : α := self.get idx | |
-- View a dense vector as a function-backed `Vec` that looks entries up with `v[i]`.
instance : HasVec (DenseVec ι α) ι α where | |
toVec v := vec i => v[i] | |
/-- Materialize any `HasVec` value as a dense vector by visiting indices in
`Enumerable.listOf ι` order and pushing each entry onto an empty array.
TODO: `hasSize` is still `sorry`; it would follow from `Enumerable.listOf ι`
having length `Enumerable.card ι`. -/
def of [HasVec β ι α] (v : β) : DenseVec ι α where | |
array := do | |
let v' := toVec v | |
let mut a := Array.empty | |
for i in Enumerable.listOf ι do | |
a := a.push $ v'[i] | |
return a | |
hasSize := sorry -- need to define differently to be able to easily prove this | |
/-- Set the value associated to a particular index. -/ | |
def set (v : DenseVec ι α) (i : ι) (a : α) : DenseVec ι α where | |
array := v.array.set (v.translateIdx i) a | |
-- `Array.set` preserves the size, so the invariant carries over from `v`.
hasSize := by rw [Array.sizeSetEq, v.hasSize] | |
/-- Iterate over the entries in backing-array order by delegating to `Array.forIn`;
presumably this is what `for x in v do` resolves to for dense vectors. -/
def forIn {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : DenseVec ι α) (b : β) (f : α → β → m (ForInStep β)) : m β := | |
as.array.forIn b f | |
/-- The constant dense vector; an alias for `fill`. -/
def pure (x : α) : DenseVec ι α := fill x | |
-- TODO: a `Monad (DenseVec ι)` instance is not written yet; the `dot`/`add`
-- drafts below are disabled (commented out).
--instance [Enumerable ι] : Monad (DenseVec ι) where | |
/- | |
def dot (v w : DenseVec ι Float) : Float := do | |
let mut s : Float := 0 | |
for i in Enumerable.listOf ι do | |
s := s + v[i] * w[i] | |
return s | |
def add (v w : DenseVec ι Float) : DenseVec ι Float := do | |
let mut v := v | |
for i in Enumerable.listOf ι do | |
v := v.set i (v[i] + w[i]) | |
return v | |
-/ | |
end DenseVec | |
-- Converting a list to an array preserves its length.
-- TODO: still unproved (`sorry`); `List.toDenseVec` relies on it for its size invariant.
theorem List.toArraySizeEq (x : List α) : x.toArray.size = x.length := sorry | |
/-- Turn a list into a dense vector indexed by its positions `Fin x.length`. -/
def List.toDenseVec (x : List α) : DenseVec (Fin x.length) α :=
  ⟨x.toArray, List.toArraySizeEq x⟩
-- `![a, b, c]` builds a dense vector indexed by `Fin n` from a list literal,
-- via `List.toDenseVec`.
syntax "![" sepBy(term, ", ") "]" : term | |
macro_rules | |
| `(![ $elems,* ]) => `(List.toDenseVec [ $elems,* ]) | |
-- Sanity check: a three-element literal is a `DenseVec (Fin 3) Nat`.
example : DenseVec (Fin 3) Nat := ![2,22,222] | |
open Enumerable

/-- Demo: build a dense vector, transform it pointwise through `Vec`,
then print each entry and the total. -/
def main : IO UInt32 := do
  let start : DenseVec (Fin 3) Nat := ![2,22,222]
  -- Pointwise: each entry x becomes x + x + 2.
  let v : DenseVec (Fin 3) Nat := DenseVec.of (toVec start + toVec start + 2)
  for x in v do
    IO.println s!"elt of v is {x}"
  let s : Nat := Vec.sum (toVec v)
  IO.println s!"sum = {s}"
  pure 0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment