-- Inductive n-tensor representation proof of concept in Agda.
--
-- Originally published as a GitHub gist by @lexi-lambda
-- (gist 5bec3f33b1db4269fc129242b53b5f43, last active November 17, 2020).
module Matrix where
open import Data.Fin as Fin using (Fin; zero; suc)
open import Data.List as List using (List; []; _∷_; product)
open import Data.Nat as Nat
open import Data.Nat.Properties as Nat
open import Data.Product as Prod
open import Data.Vec as Vec using (Vec; []; _∷_)
open import Level using (Level)
open import Relation.Binary.PropositionalEquality using (_≡_; refl)
-- -----------------------------------------------------------------------------
-- ranges
module Range where
  -- A contiguous range of natural-number indices, described by where it
  -- starts and how many indices it covers.
  record Range : Set where
    field
      offset : ℕ -- first index in the range
      length : ℕ -- number of indices in the range
  open Range public

  -- Half-open range from i (inclusive) to j (exclusive). Because _∸_ is
  -- truncated subtraction, a range with j ≤ i has length 0 rather than
  -- being rejected.
  [_,_⟩ : ℕ -> ℕ -> Range
  [ i , j ⟩ = record { offset = i; length = j ∸ i }

  -- Closed range from i to j, both inclusive, expressed as the half-open
  -- range [ i , suc j ⟩.
  [_,_] : ℕ -> ℕ -> Range
  [ i , j ] = [ i , suc j ⟩

  -- Exclusive end of a range: offset + length.
  bound₎ : Range -> ℕ
  bound₎ r = offset r + length r
module FinRange where
  open Range using (Range)

  -- A Range together with some slack s, indexed by n = bound₎ r + s.
  -- A value of type FinRange n therefore describes a range that ends at or
  -- before n, i.e. one that fits inside a collection of n elements.
  data FinRange : ℕ -> Set where
    fromRange : (r : Range) -> (s : ℕ) -> FinRange (Range.bound₎ r + s)

  -- The underlying range, forgetting the slack.
  toRange : ∀ {n} -> FinRange n -> Range
  toRange (fromRange r _) = r

  -- First index of the range.
  offset : ∀ {n} -> FinRange n -> ℕ
  offset r = Range.offset (toRange r)

  -- Number of indices in the range.
  length : ∀ {n} -> FinRange n -> ℕ
  length r = Range.length (toRange r)

  -- Number of indices between the end of the range and n.
  slack : ∀ {n} -> FinRange n -> ℕ
  slack (fromRange _ s) = s

  -- Half-open FinRange from i (inclusive) to j (exclusive). The index is
  -- spelled out as the form Range.bound₎ (Range.[ i , j ⟩) reduces to, so
  -- the definition typechecks by computation. The slack s is implicit;
  -- presumably it is meant to be inferred from the expected FinRange n type.
  [_,_⟩ : ∀ i j {s} -> FinRange (i + (j ∸ i) + s)
  [ i , j ⟩ {s} = fromRange (Range.[ i , j ⟩) s

  -- Closed FinRange from i to j, both inclusive.
  [_,_] : ∀ i j {s} -> FinRange (i + (suc j ∸ i) + s)
  [ i , j ] {s} = fromRange (Range.[ i , j ]) s

  -- Exclusive end of the range (offset + length).
  bound₎ : ∀ {n} -> FinRange n -> ℕ
  bound₎ r = Range.bound₎ (toRange r)

  -- The size n splits exactly into offset, length and slack. Matching on
  -- this proof lets code that is generic in n see n as that sum.
  decomp : ∀ {n} -> (r : FinRange n) -> n ≡ offset r + length r + slack r
  decomp (fromRange _ _) = refl
open FinRange using (FinRange; [_,_⟩; [_,_])
-- Extract the elements of a vector selected by a FinRange.
-- Matching FinRange.decomp r against refl refines the vector's length to
-- o + l + s. After reassociating that to o + (l + s), dropping o elements
-- and then taking l of the remaining l + s yields exactly the slice.
sliceVec : {l : Level} -> {A : Set l} -> {n : _} -> Vec A n -> (r : FinRange n) -> Vec A (FinRange.length r)
sliceVec v r with FinRange.offset r | FinRange.length r | FinRange.slack r | FinRange.decomp r
-- Note: in this clause the pattern variable l is the range length, not the
-- Level l from the type signature.
sliceVec v _ | o | l | s | refl
  rewrite +-assoc o l s = Vec.take l (Vec.drop o v)
-- -----------------------------------------------------------------------------
-- basic definitions
-- A multi-dimensional array (tensor) of elements of type A, indexed by its
-- shape: a list of dimension sizes, outermost first. Shape [] is a single
-- scalar. Shape n ∷ ns is a vector of n sub-tensors, each of shape ns.
data Mat (A : Set) : List ℕ -> Set where
  scalar : A -> Mat A []
  vector : ∀ {n ns} -> Vec (Mat A ns) n -> Mat A (n ∷ ns)
-- Flatten a tensor into a vector of all its elements. Sub-tensors are laid
-- out one after another, so the outermost dimension varies slowest
-- (row-major order). The length is the product of the shape's dimensions.
toVec : ∀ {A ns} -> Mat A ns -> Vec A (product ns)
toVec (vector rows) = Vec.concat (Vec.map toVec rows)
toVec (scalar a)    = a ∷ []
-- Rebuild a tensor of shape ns from a flat vector of its elements in
-- row-major order (intended as the inverse of toVec; not proven here).
fromVec : ∀ {A ns} -> Vec A (product ns) -> Mat A ns
-- Shape [] holds exactly one element, since product [] is 1.
fromVec {_} {[]} (x ∷ []) = scalar x
-- Split the n * product ns elements into n chunks of product ns elements.
-- Vec.group also returns a proof that xs is the concatenation of the
-- chunks; it is matched against refl here.
fromVec {_} {n ∷ ns} xs with Vec.group n (product ns) xs
... | ys , refl = vector (Vec.map fromVec ys)
-- Reinterpret a tensor under another shape with the same total number of
-- elements, by flattening it and rebuilding it in row-major order.
-- The size equality is an instance argument, so for concrete shapes it is
-- left to instance search (presumably finding refl, which the Agda builtin
-- equality declares as an instance — confirm for the Agda version in use).
reshape : ∀ {A ns ms} -> {{product ns ≡ product ms}} -> Mat A ns -> Mat A ms
reshape {{eq}} m with v <- toVec m rewrite eq = fromVec v
-- -----------------------------------------------------------------------------
-- slicing
-- A description of how to slice a tensor of shape xs into a tensor of
-- shape ys, consuming the shape one leading dimension at a time.
data Slice : List ℕ -> List ℕ -> Set where
  -- keep all remaining dimensions unchanged
  all : ∀ {xs} -> Slice xs xs
  -- select a single index of the leading dimension, removing that dimension
  index : ∀ {n xs ys} -> Fin n -> Slice xs ys -> Slice (n ∷ xs) ys
  -- keep a sub-range of the leading dimension; its size becomes the range's length
  range : ∀ {n xs ys} -> (r : FinRange n) -> Slice xs ys -> Slice (n ∷ xs) (FinRange.length r ∷ ys)
-- Apply a slice description to a tensor, one leading dimension at a time:
-- `range` keeps the selected sub-tensors and slices each of them further,
-- `index` descends into a single sub-tensor, and `all` stops slicing.
slice : ∀ {A xs ys} -> Slice xs ys -> Mat A xs -> Mat A ys
slice (range r rest) (vector rows) = vector (Vec.map (slice rest) (sliceVec rows r))
slice (index k rest) (vector rows) = slice rest (Vec.lookup rows k)
slice all m = m
-- -----------------------------------------------------------------------------
-- examples
open import Data.Integer as Int using (ℤ)
-- A rank-1 tensor holding 1, 2, 3.
mat₁ : Mat ℤ (3 ∷ [])
mat₁ = vector (Vec.map scalar (Int.+ 1 ∷ Int.+ 2 ∷ Int.+ 3 ∷ []))

-- A 5 × 4 × 3 tensor whose elements are 0, 1, …, 59 in row-major order.
mat₂ : Mat ℤ (5 ∷ 4 ∷ 3 ∷ [])
mat₂ = fromVec (Vec.tabulate λ i -> Int.+ (Fin.toℕ i))

-- Take indices 2–4 of the first dimension, index 1 of the second, and
-- indices 0–1 of the third.
slice₁ : Slice (5 ∷ 4 ∷ 3 ∷ []) (3 ∷ 2 ∷ [])
slice₁ =
  range [ 2 , 5 ⟩
    (index (suc zero)
      (range [ 0 , 2 ⟩
        all))

-- Applying slice₁ to mat₂ gives [[27, 28], [39, 40], [51, 52]].
mat₃ : Mat ℤ (3 ∷ 2 ∷ [])
mat₃ = slice slice₁ mat₂

-- The same six elements viewed as a flat rank-1 tensor.
mat₄ : Mat ℤ (6 ∷ [])
mat₄ = reshape mat₃
-- end of module Matrix