Skip to content

Instantly share code, notes, and snippets.

@kmill
Created December 4, 2020 01:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kmill/5eb83cb5fd102e4242c2c8a53224ba13 to your computer and use it in GitHub Desktop.
Save kmill/5eb83cb5fd102e4242c2c8a53224ba13 to your computer and use it in GitHub Desktop.
-- Changes: packed (multi-)indices now use cartesian products (`ι₁ × ι₂`).
-- Added slicing (see section `extraction`).
import control.applicative
import tactic
import data.real.basic
import algebra.module.basic
-- Universe variables for the definitions below that pin universes explicitly.
universes u v
/-- `vec ι α`, written `ι => α`, is the type of vectors with entries in `α` indexed by `ι`.
It is definitionally the function type `ι → α`. -/
def vec (ι : Type*) (α : Type*) := ι → α
-- `=>` is right-associative: `a => b => c` means `a => (b => c)`.
infixr ` => `:30 := vec
/-- View a function `ι → α` as a vector `ι => α`. -/
def to_vec {ι : Type*} {α : Type*} (f : ι → α) : ι => α := f
/-- Pack two index levels into a single product index: `v.push (i, j)` is `v i j`. -/
def vec.push {ι₁ ι₂ α : Type*} (v : ι₁ => ι₂ => α) : ι₁ × ι₂ => α :=
λ q, v q.fst q.snd
/-- Split a product index into two index levels: `v.pop i j` is `v (i, j)`. -/
def vec.pop {ι₁ ι₂ α : Type*} (v : ι₁ × ι₂ => α) : ι₁ => ι₂ => α :=
λ a b, v (a, b)
-- Not needed: for `v : punit => α`, `v ()` already extracts the scalar.
--def vec.scalar {α : Type*} (v : punit => α) : α := v ()
section joining_index_types
/-!
Type classes that put the indices of a (possibly nested) vector into a normal form:
a single right-associated cartesian product of index types.

`has_join α ι α'` says that a value of type `α` can be viewed as a vector indexed by `ι`
with entries of type `α'`. Instances are tried from highest to lowest priority:
* `reassoc` (1200): reassociate a left-nested product index `(ι × ι') × ι''` to
  `ι × ι' × ι''`, then keep joining;
* `recurse_p` (1100): keep the first factor of a product index and join the rest;
* `recurse` (1000): pair the outer index with the joined indices of the entries;
* `base₁` (900): any other vector is already in normal form.
-/
class has_join (α : Type*) (ι : out_param $ Type*) (α' : out_param $ Type*) :=
(join : α → ι => α')

-- After reassociating, continue through `has_join` rather than stopping. Deeper left
-- nestings such as `((a × b) × c) × d => e`, and vector-valued entries such as
-- `(a × b) × c => d => e`, are then fully joined. Each step strictly reduces left
-- nesting, so instance search terminates.
@[priority 1200]
instance has_join.reassoc {ι ι' ι'' α κ α' : Type*} [has_join (ι × ι' × ι'' => α) κ α'] :
  has_join ((ι × ι') × ι'' => α) κ α' :=
{ join := λ v, has_join.join
    (to_vec (λ p : ι × ι' × ι'', v ((equiv.prod_assoc ι ι' ι'').symm p))) }

-- Product index `ι × ι'`: keep `ι` and join the vector obtained by fixing it.
@[priority 1100]
instance has_join.recurse_p {ι ι' ι'' α α' : Type*} [has_join (ι' => α) ι'' α'] :
has_join (ι × ι' => α) (ι × ι'') α' :=
{ join := λ v p, has_join.join (v.pop p.1) p.2 }

-- Entries that can themselves be joined: pair the outer index with their joined index.
@[priority 1000]
instance has_join.recurse {ι ι' α α' : Type*} [has_join α ι' α'] : has_join (ι => α) (ι × ι') α' :=
{ join := λ v p, has_join.join (v p.1) p.2 }

-- Fallback: any vector is already joined, unchanged.
@[priority 900]
instance has_join.base₁ {ι α : Type*} : has_join (ι => α) ι α :=
{ join := id }

/-- Join all the indices into a single right-associated cartesian product. -/
def vec.join {α ι α'} [has_join α ι α'] : α → ι => α' := has_join.join
-- examples
--variables (a b c d e : Type) (v : a×b×a=>c×d=>e) (w : (a×b)×c=>d) (u : a=>b=>c) (i : a) (j : b)
--#check v.join
-- gives vec.join v : a × b × a × c × d => e
end joining_index_types
/-- Precompose a vector with a map on indices: `v.reindex f j` is `v (f j)`.
This is the action of the contravariant functor `hom(-, α)` on morphisms. -/
def vec.reindex {ι ι' α : Type*} (v : ι => α) (f : ι' → ι) : ι' => α :=
λ j, v $ f j
section finite_vectors
/-! Finite vectors `fin n => α` and the list-like notation `![a, b, c]`.
Element types are universe-polymorphic (`Type*`), matching the rest of the file. -/

/-- The empty vector (length `0`). -/
def tensor.empty {α : Type*} : fin 0 => α := fin_zero_elim
/-- Prepend `h` to the length-`n` vector `t`, using mathlib's `fin.cons`. -/
def tensor.cons {α : Type*} {n : ℕ} (h : α) (t : fin n => α) : fin n.succ => α := fin.cons h t
-- `![a, b, c]` expands to `tensor.cons a (tensor.cons b (tensor.cons c tensor.empty))`.
notation `![` l:(foldr `, ` (h t, tensor.cons h t) tensor.empty `]`) := l
end finite_vectors
-- `for i j, e` builds a (nested) vector: it expands to `to_vec (λ i, to_vec (λ j, e))`.
notation `for` binders `, ` r:(scoped f, to_vec f) := r
section monad
/-- `vec ι` is a monad. `pure` builds the constant vector; `map` and `seq` act entrywise;
`bind x g` takes the diagonal, so entry `i` is entry `i` of `g (x i)`. -/
instance (ι : Type*) : monad (vec ι) :=
{ pure := λ α a, to_vec (λ i, a),
  map  := λ α β g x, to_vec (λ i, g (x i)),
  seq  := λ α β g x, to_vec (λ i, g i (x i)),
  bind := λ α β x g, to_vec (λ i, g (x i) i) }

/-- The constant vector whose every entry is `v` (that is, `pure v`). -/
abbreviation vec.lift {ι : Type*} {α : Type*} (v : α) : ι => α := pure v
end monad
section extraction
/-! Indexing and slicing in a single notation, `x⟦…⟧`. -/
/-- One component of an extraction: keep a whole axis (`slice`) or select the entry at
index `i` (`index i`). -/
inductive part (ι : Type u) : Type u
| slice : part
| index (i : ι) : part
-- Lets `slice` be written unqualified (used in `part.type` and `vec.extract'` below).
open part(slice)
/-- Conversion of the components written inside `x⟦…⟧` into a `part ι`.
`α` and `ι` are required to live in the same universe `u`. -/
class to_part (α : Type u) (ι : Type u) :=
(coe : α → part ι)
-- NOTE(review): presumably reducible so that `part.type (to_part.coe h) β` can be unfolded
-- while elaborating the `⟦…⟧` notation — confirm.
attribute [reducible] to_part.coe
-- `()` means "slice this axis".
instance part.from_unit {ι : Type*} : to_part punit ι :=
{ coe := λ _, part.slice }
-- An explicit `part` is passed through unchanged.
instance part.from_part {ι : Type*} : to_part (part ι) ι :=
{ coe := id }
-- A value of the index type selects that entry.
instance part.from_type {ι : Type*} : to_part ι ι :=
{ coe := part.index }
/-- Result type of extracting with `p` when the (inner-extracted) entries have type `α`:
`ι => α` for `slice`, just `α` for `index i`. -/
@[reducible]
def part.type {ι : Type u} : part ι → Type u → Type u
| slice α := ι => α
| (part.index i) α := α
/-- Apply `f` to the part of `v` selected by `p`. With `slice`, map `f` over every entry
(`f <$> v`); with `index i`, return `f (v i)`.
NOTE(review): `f <$> v` goes through `functor.map` for `vec ι`, whose signature needs `α`
and `β` in the same universe, while the binders here are independent `Type*`. Confirm
this elaborates, or restrict `α β` to a shared universe. -/
@[reducible]
def vec.extract' {ι α β : Type*} (f : α → β) : Π (p : part ι), (ι => α) → part.type p β
| slice v := (f <$> v : ι => β)
| (part.index i) v := f (v i)
-- `x⟦h₁, h₂, …⟧` converts each `hₖ` with `to_part.coe`. The components are folded from the
-- right into nested `vec.extract'` calls starting from `id`, so `h₁` acts on the outermost
-- index of `x`, `h₂` on the next, and so on.
notation x `⟦` l:(foldr `, ` (h t, vec.extract' t (to_part.coe h)) id `⟧`) := l x
-- Example
--variables (a b c : Type) (u : a=>b=>c) (i : a) (j : b)
--#reduce u⟦i⟧
-- gives u i
--#reduce u⟦(),j⟧
-- gives λ (i : a), u i j
--#reduce u⟦i,()⟧
-- gives λ (i_1 : b), u i i_1
end extraction
/-- Swap the two index levels: `transpose v a b = v b a`. -/
def transpose {r s α : Type*} (v : r => s => α) : s => r => α :=
for a b, v b a
/-- Entrywise sum of two vectors with the same index type. -/
def addvec {r α : Type*} [add_monoid α] (v w : r => α) : r => α :=
for k, v k + w k
/-- Entrywise sum written applicatively: lift `(+)` over `v` and apply the result to `w`. -/
def addvec' {r α : Type u} [add_monoid α] (v w : r => α) : r => α :=
has_seq.seq ((+) <$> v) w
/-- Sum of all entries of a vector over a finite index type (`finset.sum` over `finset.univ`).

`ι` and `α` get independent universes. Callers such as `matvec`, `contract` and `dot` bind
them to separate `Type*` variables, which a single shared `Type u` would not accept. -/
def vsum {ι : Type*} {α : Type*} [fintype ι] [add_comm_monoid α] (v : ι => α) : α :=
finset.sum finset.univ v
/-- Matrix-vector product for a nested matrix: entry `row` is `∑ col, A row col * v col`. -/
def matvec {r s α : Type*} [fintype s] [comm_ring α] (A : r => s => α) (v : s => α) : r => α :=
for row, vsum (for col, A row col * v col)
/-- Matrix-vector product for a matrix with a packed index `r × s`, indexing by pairs. -/
def matvec' {r s α : Type*} [fintype s] [comm_ring α] (A : r × s => α) (v : s => α) : r => α :=
for row, vsum (for col, A (row, col) * v col)
/-- Matrix-vector product for a packed matrix, unpacking the index with `vec.pop` first. -/
def matvec'' {r s α : Type*} [fintype s] [comm_ring α] (A : r × s => α) (v : s => α) : r => α :=
for row, vsum (for col, vec.pop A row col * v col)
/-- Matrix-vector product for a packed matrix, written point-free. The steps are:
broadcast `v` along `r` (`v.lift.push`), multiply entrywise with `A`, unpack the index
(`pop`), and sum each row with `vsum`.

`s` and `α` share the universe `u` because `vsum <$> _` maps `s => α` to `α` through the
`vec r` functor. `functor.map` requires its source and target types to live in the same
universe, so independent `Type*` binders for `s` and `α` are rejected. -/
def matvec''' {r : Type*} {s α : Type u} [fintype s] [comm_ring α] (A : r × s => α) (v : s => α) : r => α :=
vsum <$> ((*) <$> A <*> v.lift.push).pop
/-- Contract the last index of `T` with the first index of `T'`:
entry `a b` is `∑ j, T a j * T' j b` (the slice `T'⟦(), b⟧` is `λ j, T' j b`). -/
def contract {r r' s : Type*} {α : Type*} [fintype s] [comm_ring α]
(T : r=>s=>α) (T' : s=>r'=>α) : r=>r'=>α :=
for a b, vsum ((*) <$> T a <*> T'⟦(), b⟧)
/-- Contract the 2nd and 4th indices of a rank-5 tensor:
entry `a b c` is `∑ m, T a m b m c`. -/
def contract' {r r' r'' s : Type*} {α : Type*} [fintype s] [comm_ring α]
(T : r=>s=>r'=>s=>r''=>α) : r=>r'=>r''=>α :=
for a b c, vsum (for m, T a m b m c)
/-- Contract the 1st and 3rd indices of a rank-4 tensor:
entry `a b` is `∑ m, T m a m b`. -/
def contract'' {r r' s : Type*} {α : Type*} [fintype s] [comm_ring α]
(T : s=>r=>s=>r'=>α) : r=>r'=>α :=
for a b, vsum (for m, T m a m b)
/-- Trace of a square nested matrix: the sum of its diagonal entries `T k k`. -/
def trace' {s α : Type*} [fintype s] [comm_ring α] (T : s=>s=>α) : α :=
vsum (for k, T k k)
/-- Outer product of two vectors: entry `a b` is `T a * T' b`. -/
def vec.prod {r r' : Type*} {α : Type*} [monoid α] (T : r=>α) (T' : r'=>α) : r=>r'=>α :=
for a b, T a * T' b
-- NOTE(review): presumably needed because the real-number definitions below
-- (e.g. `real.sqrt` in `length`) are noncomputable — confirm.
noncomputable theory
--set_option trace.class_instances true
/-- Real vectors of length `n`. -/
def Vec (n : ℕ) := fin n => ℝ
/-- Real `n`-by-`m` matrices, indexed first by `fin n` and then by `fin m`. -/
def Mat (n m : ℕ) := fin n => fin m => ℝ
--
/-- Rectified linear unit: `relu x = max x 0`. -/
def relu (x : ℝ) : ℝ := max x (0 : ℝ)
/-- Euclidean length: the square root of the sum of the squared entries. -/
def length {d : Type} [fintype d] (x : d=>ℝ) : ℝ :=
real.sqrt (vsum (for k, x k ^ 2))
/-- Divide every entry of `x` by `length x`. -/
def normalize {d : Type} [fintype d] (x : d=>ℝ) : d=>ℝ :=
(/) <$> x <*> (pure $ length x)
/-- Pair of `x` with every entry divided by `length x`, together with `length x`
itself; `length x` is computed once and shared via `let`. -/
def directionAndLength {d : Type} [fintype d] (x : d=>ℝ) : (d=>ℝ) × ℝ :=
let len := length x in
⟨(/) <$> x <*> pure len, len⟩
/-- Linear combination `∑ j, s j • vs j` of the vectors `vs j` with real coefficients `s j`. -/
def dot {V:Type*} [add_comm_group V] [vector_space ℝ V] {d:Type} [fintype d]
(s:d=>ℝ) (vs:d=>V) : V :=
vsum (for k, s k • vs k)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment