-
-
Save kmill/5eb83cb5fd102e4242c2c8a53224ba13 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- switched to using cartesian products for packed indices | |
-- implemented slicing (see `extraction`) | |
import control.applicative | |
import tactic | |
import data.real.basic | |
import algebra.module.basic | |
universes u v | |
/-- `vec ι α` is the type of `α`-valued vectors indexed by `ι`; it is represented by the
plain function type `ι → α`. -/
def vec (ι α : Type*) := ι → α

-- `ι => α` is infix notation for `vec ι α`. It associates to the right, so
-- `ι => κ => α` reads as `ι => (κ => α)`.
infixr ` => `:30 := vec
/-- View a function `f : ι → α` as a vector `ι => α`. The underlying term is `f` itself;
only the displayed type changes. -/
def to_vec {ι α : Type*} (f : ι → α) : vec ι α := f
/-- Merge the two outermost indices of a nested vector into a single index ranging over
the cartesian product `ι₁ × ι₂`. -/
def vec.push {ι₁ ι₂ α : Type*} (v : ι₁ => ι₂ => α) : ι₁ × ι₂ => α :=
-- Projections (rather than pattern matching) keep the definition unfolding on any pair.
λ ij, v ij.fst ij.snd
/-- Split a product index `ι₁ × ι₂` back into two nested indices: entry `i`, `j` of the
result is entry `(i, j)` of `v`. -/
def vec.pop {ι₁ ι₂ α : Type*} (v : ι₁ × ι₂ => α) : ι₁ => ι₂ => α :=
λ a b, v (a, b)
-- unnecessary since v() is the same thing. | |
--def vec.scalar {α : Type*} (v : punit => α) : α := v () | |
section joining_index_types | |
/-! Code to put indices into normal form with cartesian products. -/
/-- `has_join α ι α'` provides a way to turn a value of the (nested) vector type `α` into a
vector `ι => α'` over a single index type `ι`. Both `ι` and `α'` are `out_param`s, so they
are determined by instance resolution from `α`. -/
class has_join (α : Type*) (ι : out_param (Type*)) (α' : out_param (Type*)) :=
(join : α → ι => α')
/-- Re-associate a left-nested product index `(ι × ι') × ι''` to the right-nested form
`ι × ι' × ι''`. This instance carries the highest priority of the `has_join` instances in
this section. -/
@[priority 1200]
instance has_join.reassoc {ι ι' ι'' α : Type*} :
  has_join ((ι × ι') × ι'' => α) (ι × ι' × ι'') α :=
⟨λ v p, v ((equiv.prod_assoc ι ι' ι'').symm p)⟩
/-- For a vector indexed by a product `ι × ι'`: keep the first component `ι` and
recursively join whatever remains after popping it off (`ι' => α`). -/
@[priority 1100]
instance has_join.recurse_p {ι ι' ι'' α α' : Type*} [has_join (ι' => α) ι'' α'] :
  has_join (ι × ι' => α) (ι × ι'') α' :=
⟨λ v p, has_join.join (vec.pop v p.fst) p.snd⟩
/-- For a vector whose entries can themselves be joined: keep the outer index `ι` and pair
it with the joined index `ι'` of the entries. -/
@[priority 1000]
instance has_join.recurse {ι ι' α α' : Type*} [has_join α ι' α'] :
  has_join (ι => α) (ι × ι') α' :=
⟨λ v p, has_join.join (v p.fst) p.snd⟩
/-- Base case: a vector is left unchanged. This instance carries the lowest priority of
the `has_join` instances in this section. -/
@[priority 900]
instance has_join.base₁ {ι α : Type*} : has_join (ι => α) ι α :=
⟨id⟩
/-- Join all the indices into a single right-associated cartesian product.

For instance, with `v : a×b×a=>c×d=>e`, `v.join` has type `a × b × a × c × d => e`. -/
def vec.join {α ι α'} [has_join α ι α'] : α → ι => α' :=
λ v, has_join.join v
-- examples | |
--variables (a b c d e : Type) (v : a×b×a=>c×d=>e) (w : (a×b)×c=>d) (u : a=>b=>c) (i : a) (j : b) | |
--#check v.join | |
-- gives vec.join v : a × b × a × c × d => e | |
end joining_index_types | |
-- hom(-, α) functor | |
/-- Pull a vector back along an index map: entry `k` of `v.reindex f` is entry `f k` of
`v`. -/
def vec.reindex {ι ι' α : Type*} (v : ι => α) (f : ι' → ι) : ι' => α :=
λ k, v (f k)
section finite_vectors | |
/-- The vector with no entries: `fin 0` has no elements to index by. -/
def tensor.empty {α : Type} : fin 0 => α :=
fin_zero_elim

/-- Prepend the entry `h` to the `n`-entry vector `t`, giving an `n.succ`-entry vector. -/
def tensor.cons {α : Type} {n : ℕ} (h : α) (t : fin n => α) : fin (nat.succ n) => α :=
fin.cons h t

-- `![a, b, c]` folds to `tensor.cons a (tensor.cons b (tensor.cons c tensor.empty))`.
notation `![` l:(foldr `, ` (h t, tensor.cons h t) tensor.empty `]`) := l
end finite_vectors | |
-- Binder notation for building vectors: `for i, e` stands for `to_vec (λ i, e)`.
-- With several binders the construction nests, e.g. `for i j, e` builds a vector of
-- vectors.
notation `for` binders `, ` r:(scoped f, to_vec f) := r | |
section monad | |
/-- Monad structure on `vec ι`; every operation acts pointwise in the index `i`:
* `pure a` is the constant vector with every entry `a`;
* `map g v` applies `g` to each entry of `v`;
* `seq g v` applies the `i`-th function of `g` to the `i`-th entry of `v`;
* `bind v g` takes entry `i` of the vector `g (v i)`. -/
instance (ι : Type*) : monad (vec ι) :=
{ pure := λ α a, to_vec (λ i, a),
  map  := λ α β g v, to_vec (λ i, g (v i)),
  seq  := λ α β g v, to_vec (λ i, g i (v i)),
  bind := λ α β v g, to_vec (λ i, g (v i) i) }
/-- The constant vector all of whose entries are `v` (this is `pure` for `vec ι`). -/
abbreviation vec.lift {ι α : Type*} (v : α) : ι => α := (pure v : vec ι α)
end monad | |
section extraction | |
/-! code for indexing and slicing simultaneously -/ | |
/-- What to do with one axis indexed by `ι` when extracting from a vector: either keep the
whole axis (`slice`) or select the single position `i` (`index i`). -/
inductive part (ι : Type u) : Type u
-- keep the entire axis
| slice : part
-- pick out position `i` along the axis
| index (i : ι) : part

-- Make the constructor `part.slice` available under the short name `slice`.
open part (slice)
/-- `to_part α ι` says how a value of type `α` is read as a `part ι`, i.e. as an instruction
for one axis in the `⟦ ⟧` extraction notation. -/
class to_part (α ι : Type u) :=
(coe : α → part ι)

-- Marked reducible so that the conversion unfolds when extraction expressions are
-- elaborated and reduced.
attribute [reducible] to_part.coe
/-- Any value of `punit` (written `()`) stands for `part.slice`: keep the whole axis. -/
instance part.from_unit {ι : Type*} : to_part punit ι :=
⟨λ _, part.slice⟩
/-- A `part ι` is used as is. -/
instance part.from_part {ι : Type*} : to_part (part ι) ι :=
⟨id⟩
/-- An index `i : ι` stands for `part.index i`: select that single position. -/
instance part.from_type {ι : Type*} : to_part ι ι :=
⟨part.index⟩
/-- Result type of extracting along one axis with entries of type `β`: a slice keeps the
axis (`ι => β`), whereas an index removes it (`β`). -/
@[reducible]
def part.type {ι : Type u} : part ι → Type u → Type u
| part.slice     β := ι => β
| (part.index _) β := β
/-- Process the outermost axis of a vector according to the part `p`, using `f` on what
remains:
* for `slice`, the axis is kept and `f` is mapped over every entry;
* for `index i`, entry `i` is selected and `f` is applied to it. -/
@[reducible]
def vec.extract' {ι α β : Type*} (f : α → β) : Π (p : part ι), (ι => α) → part.type p β
| part.slice     w := (f <$> w : ι => β)
| (part.index k) w := f (w k)
-- `x⟦a, b, …⟧` indexes and/or slices `x`, one bracket entry per axis, outermost axis
-- first. Each entry is converted with `to_part.coe`; the entries are folded from the
-- right into nested `vec.extract'` calls (innermost continuation `id`) and the resulting
-- function is applied to `x`. For `u : a=>b=>c`, `u⟦i⟧` reduces to `u i` and `u⟦(),j⟧`
-- reduces to `λ i, u i j`.
notation x `⟦` l:(foldr `, ` (h t, vec.extract' t (to_part.coe h)) id `⟧`) := l x | |
-- Example | |
--variables (a b c : Type) (u : a=>b=>c) (i : a) (j : b) | |
--#reduce u⟦i⟧ | |
-- gives u i | |
--#reduce u⟦(),j⟧ | |
-- gives λ (i : a), u i j | |
--#reduce u⟦i,()⟧ | |
-- gives λ (i_1 : b), u i i_1 | |
end extraction | |
/-- Swap the two outermost indices: entry `j`, `i` of the result is entry `i`, `j` of `v`. -/
def transpose {r s α : Type*} (v : r => s => α) : s => r => α :=
to_vec (λ j, to_vec (λ i, v i j))
/-- Entrywise sum of two vectors, written with an explicit index. -/
def addvec {r α : Type*} [add_monoid α] (v w : r => α) : r => α :=
to_vec (λ k, v k + w k)
/-- Entrywise sum of two vectors, written index-free with the applicative operators of
`vec r`: map `+` over `v`, then apply the resulting vector of functions to `w`. -/
def addvec' {r α : Type u} [add_monoid α] (v w : r => α) : r => α :=
(has_add.add <$> v) <*> w
/-- Sum of all entries of a vector over a finite index type, i.e. `finset.sum` over
`finset.univ`. -/
def vsum {ι α : Type u} [fintype ι] [add_comm_monoid α] (v : ι => α) : α :=
(finset.univ : finset ι).sum v
/-- Matrix–vector product for a matrix given as a vector of rows: entry `i` of the result
is the sum over `j` of `A i j * v j`.

Only multiplication and a finite sum are used, so the entries are merely required to form
a `semiring`; every `comm_ring` still qualifies. -/
def matvec {r s α : Type*} [fintype s] [semiring α] (A : r => s => α) (v : s => α) : r => α :=
for i, vsum for j, A i j * v j
/-- Matrix–vector product for a matrix indexed by pairs `(i, j) : r × s`: entry `i` of the
result is the sum over `j` of `A (i, j) * v j`.

Only multiplication and a finite sum are used, so the entries are merely required to form
a `semiring`; every `comm_ring` still qualifies. -/
def matvec' {r s α : Type*} [fintype s] [semiring α] (A : r × s => α) (v : s => α) : r => α :=
for i, vsum for j, A (i, j) * v j
/-- Matrix–vector product for a matrix indexed by pairs, written by first un-packing the
product index with `vec.pop`: entry `i` of the result is the sum over `j` of
`A.pop i j * v j`.

Only multiplication and a finite sum are used, so the entries are merely required to form
a `semiring`; every `comm_ring` still qualifies. -/
def matvec'' {r s α : Type*} [fintype s] [semiring α] (A : r × s => α) (v : s => α) : r => α :=
for i, vsum for j, A.pop i j * v j
/-- Index-free matrix–vector product for a matrix indexed by pairs: `v` is lifted to a
constant family and pushed to the index type `r × s`, multiplied entrywise with `A`, the
product is popped back to `r => s => α`, and each row is summed with `vsum`.

Only multiplication and a finite sum are used, so the entries are merely required to form
a `semiring`; every `comm_ring` still qualifies. -/
def matvec''' {r s α : Type*} [fintype s] [semiring α] (A : r × s => α) (v : s => α) : r => α :=
vsum <$> ((*) <$> A <*> v.lift.push).pop
/-- Contract the inner index of `T` against the outer index of `T'`: entry `i`, `i'` of the
result is the sum over the shared index `s` of `T i _ * T' _ i'`. The slice `T'⟦(),i'⟧`
keeps the `s` axis of `T'` and fixes its second index to `i'`.

Only multiplication and a finite sum are used, so the entries are merely required to form
a `semiring`; every `comm_ring` still qualifies. -/
def contract {r r' s : Type*} {α : Type*} [fintype s] [semiring α]
  (T : r=>s=>α) (T' : s=>r'=>α) : r=>r'=>α :=
for i i', vsum ((*) <$> T i <*> T'⟦(),i'⟧)
/-- Contract the second and fourth indices of a five-index tensor with each other: entry
`i`, `j`, `k` of the result is the sum over `l` of `T i l j l k`.

No multiplication is involved, only a finite sum; `semiring` is kept (rather than something
weaker) so the signature stays parallel to `contract`, and every `comm_ring` still
qualifies. -/
def contract' {r r' r'' s : Type*} {α : Type*} [fintype s] [semiring α]
  (T : r=>s=>r'=>s=>r''=>α) : r=>r'=>r''=>α :=
for i j k, vsum for l, T i l j l k
/-- Contract the first and third indices of a four-index tensor with each other: entry
`j`, `k` of the result is the sum over `l` of `T l j l k`.

No multiplication is involved, only a finite sum; `semiring` is kept (rather than something
weaker) so the signature stays parallel to `contract`, and every `comm_ring` still
qualifies. -/
def contract'' {r r' s : Type*} {α : Type*} [fintype s] [semiring α]
  (T : s=>r=>s=>r'=>α) : r=>r'=>α :=
for j k, vsum for l, T l j l k
/-- Trace of a square two-index tensor: the sum of the diagonal entries `T i i`.

Only a finite sum is taken, so the entries are merely required to form an
`add_comm_monoid`; every `comm_ring` still qualifies. -/
def trace' {s α : Type*} [fintype s] [add_comm_monoid α] (T : s=>s=>α) : α :=
vsum for i, T i i
/-- Outer product of two vectors: entry `i`, `i'` of the result is `T i * T' i'`. -/
def vec.prod {r r' : Type*} {α : Type*} [monoid α] (T : r=>α) (T' : r'=>α) : r=>r'=>α :=
to_vec (λ i, to_vec (λ i', T i * T' i'))
noncomputable theory | |
--set_option trace.class_instances true | |
/-- Real vectors with `n` entries, indexed by `fin n`. -/
def Vec (n : ℕ) : Type := fin n => ℝ
/-- Real `n`-by-`m` matrices, stored as `n` rows with `m` entries each. -/
def Mat (n m : ℕ) : Type := fin n => fin m => ℝ
-- | |
/-- Rectified linear unit: the larger of `x` and `0`. -/
def relu (x : ℝ) : ℝ :=
max x 0
/-- Euclidean length of a real vector: the square root of the sum of the squares of its
entries. -/
def length {d : Type} [fintype d] (x : d=>ℝ) : ℝ :=
real.sqrt (vsum (to_vec (λ i, (x i)^2)))
/-- Divide every entry of `x` by `length x`. -/
def normalize {d : Type} [fintype d] (x : d=>ℝ) : d=>ℝ :=
(has_div.div <$> x) <*> pure (length x)
/-- Return the pair consisting of `x` with every entry divided by its length, and that
length. The length is bound once and used in both components. -/
def directionAndLength {d : Type} [fintype d] (x : d=>ℝ) : (d=>ℝ) × ℝ :=
let len := length x in ⟨(/) <$> x <*> pure len, len⟩
/-- Linear combination of the vectors `vs j` with real coefficients `s j`: the sum over
`j` of `s j • vs j`. -/
-- NOTE(review): `vector_space` is understood to have been replaced by `module` in later
-- mathlib versions — confirm against the mathlib revision this file is pinned to.
def dot {V:Type*} [add_comm_group V] [vector_space ℝ V] {d:Type} [fintype d]
  (s:d=>ℝ) (vs:d=>V) : V :=
vsum (to_vec (λ j, s j • vs j))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment