{-
  Sketch of a matrix data type in Idris.

  Source: GitHub gist gwerbin/40309efeb3324677f2faa12fb76bd2a2
  (last active May 4, 2021).
-}
module Matrix

import Data.Fin
import Data.Vect

%default total
---- Interface ----

||| The kind of a matrix type family: given a row count, a column count and an
||| entry type, it yields the type of matrices of that size.
AbstractMatrix : Type
AbstractMatrix = Nat -> Nat -> Type -> Type
||| Interface for two-dimensional matrices whose row count `m`, column count
||| `n` and entry type `t` are tracked in the type.
interface Matrix (0 matrixType : AbstractMatrix) | matrixType where
  {- Size and shape -}

  -- Laws implementations should satisfy (not yet enforced by the types):
  --   nrows x = m
  --   ncols x = n
  --   size x  = m * n

  ||| Number of rows
  nrows : {m : _} -> {n : _} -> matrixType m n t -> Nat
  nrows _ = m

  ||| Number of columns
  ncols : {m : _} -> {n : _} -> matrixType m n t -> Nat
  ncols _ = n

  ||| Number of entries
  size : {m : _} -> {n : _} -> matrixType m n t -> Nat
  size _ = m * n

  ||| Dimensions as a (rows, columns) pair
  shape : {m : _} -> {n : _} -> matrixType m n t -> (Nat, Nat)
  shape _ = (m, n)

  {- Accessing parts -}
  -- TODO: blocks, sub-matrices, strides, diagonals

  ||| Get one row
  row : (i : Fin m) -> matrixType m n t -> Vect n t

  ||| Get one column
  col : (j : Fin n) -> matrixType m n t -> Vect m t

  ||| Get one entry
  index : (i : Fin m) -> (j : Fin n) -> matrixType m n t -> t

  {- Constructing -}

  -- Laws implementations should satisfy (not yet enforced by the types):
  --   insertRow FZ   = prependRow
  --   insertCol FZ   = prependCol
  --   insertRow last = appendRow
  --   insertCol last = appendCol

  ||| Insert a row at a given position, moving lower rows down by 1
  insertRow : {m : _} -> {n : _} ->
              (i : Fin (S m)) ->
              Vect n t ->
              matrixType m n t ->
              matrixType (S m) n t

  ||| Insert a column at a given position, moving the columns at and after
  ||| that position one place to the right
  insertCol : {m : _} -> {n : _} ->
              (j : Fin (S n)) ->
              Vect m t ->
              matrixType m n t ->
              matrixType m (S n) t

  ||| Prepend a row to the top of a matrix
  prependRow : {m : _} -> {n : _} ->
               Vect n t ->
               matrixType m n t ->
               matrixType (S m) n t
  prependRow = insertRow FZ

  ||| Prepend a column to the left side of a matrix
  prependCol : {m : _} -> {n : _} ->
               Vect m t ->
               matrixType m n t ->
               matrixType m (S n) t
  prependCol = insertCol FZ

  ||| Append a row to the bottom of a matrix
  appendRow : {m : _} -> {n : _} ->
              Vect n t ->
              matrixType m n t ->
              matrixType (S m) n t
  appendRow = insertRow last

  ||| Append a column to the right side of a matrix
  appendCol : {m : _} -> {n : _} ->
              Vect m t ->
              matrixType m n t ->
              matrixType m (S n) t
  appendCol = insertCol last

  ||| Stack two matrices vertically: the rows of the second matrix are placed
  ||| below the rows of the first
  concatRows : matrixType m1 n t ->
               matrixType m2 n t ->
               matrixType (m1 + m2) n t

  ||| Join two matrices horizontally: the columns of the second matrix are
  ||| placed to the right of the columns of the first
  concatCols : matrixType m n1 t ->
               matrixType m n2 t ->
               matrixType m (n1 + n2) t

  {- Operations -}

  -- Laws implementations should satisfy (not yet enforced by the types):
  --   transpose . transpose = id
  --   (rowSwap i j) . (rowSwap j i) = id

  ||| Matrix transposition
  transpose : {m : _} -> {n : _} -> matrixType m n t -> matrixType n m t

  ||| Swap two rows, leaving all other rows unchanged
  rowSwap : (rownum1 : Fin m) -> (rownum2 : Fin m) -> matrixType m n t -> matrixType m n t

  ||| Replace row `rownum1` with the entrywise sum of rows `rownum1` and `rownum2`
  rowAdd : Num t => (rownum1 : Fin m) -> (rownum2 : Fin m) -> matrixType m n t -> matrixType m n t

  ||| Replace row `rownum` with itself, each entry multiplied by `scalar`
  rowScale : Num t => (scalar : t) -> (rownum : Fin m) -> matrixType m n t -> matrixType m n t

  ||| Matrix addition: A + B
  add : Num t => matrixType m n t -> matrixType m n t -> matrixType m n t

  --||| Matrix subtraction: A - B i.e. A + (-1 B)
  --subtract : Neg t => matrixType m n t -> matrixType m n t -> matrixType m n t

  ||| Scalar multiplication: c A
  scale : Num t => t -> matrixType m n t -> matrixType m n t

  --||| Scalar division: (1/c) A
  --scaleDiv : Fractional t => t -> matrixType m n t -> matrixType m n t

  ||| Matrix multiplication: A B
  |||
  ||| The result's column count `n` is a runtime argument: when the inner
  ||| dimension `k` is zero, B has no rows to read `n` from, yet the result
  ||| still needs `n` entries per row.
  matmul : {n : _} -> Num t => matrixType m k t -> matrixType k n t -> matrixType m n t

---- Implementation: Row-oriented vector-of-vectors ----

{- Type -}

||| A matrix stored as a vector of `m` rows, each a vector of `n` entries.
RowMatrix : (m : Nat) -> (n : Nat) -> (t : Type) -> Type
RowMatrix m n t = Vect m (Vect n t)

{- Construction -}

||| An `nrow` by `ncol` matrix in which every entry is `fillval`.
filledRowMatrix : (fillval : filltype) -> (nrow : Nat) -> (ncol : Nat) -> RowMatrix nrow ncol filltype
filledRowMatrix fillval nrow ncol = Data.Vect.replicate nrow (Data.Vect.replicate ncol fillval)

||| An `nrow` by `ncol` matrix of `Double` zeros.
zerosRowMatrix : (nrow : Nat) -> (ncol : Nat) -> RowMatrix nrow ncol Double
zerosRowMatrix = filledRowMatrix 0.0

||| An `nrow` by `ncol` matrix of `Double` ones.
onesRowMatrix : (nrow : Nat) -> (ncol : Nat) -> RowMatrix nrow ncol Double
onesRowMatrix = filledRowMatrix 1.0

{- Core Matrix interface -}

implementation Matrix RowMatrix where
  {- Size and shape: nrows, ncols, size and shape use the interface defaults -}

  {- Accessing parts -}
  row = Data.Vect.index
  col j = map (Data.Vect.index j)
  index i j = Data.Vect.index j . Data.Vect.index i

  {- Constructing -}
  insertRow = insertAt
  insertCol j = zipWith (insertAt j)
  prependRow = Data.Vect.(::)
  prependCol = zipWith (::)
  appendRow = flip snoc
  appendCol newCol rows = zipWith snoc rows newCol
  concatRows = Data.Vect.(++)
  concatCols = zipWith (++)

  {- Operations -}
  transpose = Data.Vect.transpose

  -- Both rows are read from the ORIGINAL matrix before either is written;
  -- reading the second row after the first write would duplicate one row.
  rowSwap i1 i2 rows =
    let r1 = Data.Vect.index i1 rows
        r2 = Data.Vect.index i2 rows
    in replaceAt i2 r1 (replaceAt i1 r2 rows)

  rowAdd i1 i2 rows =
    replaceAt i1 (zipWith (+) (Data.Vect.index i1 rows) (Data.Vect.index i2 rows)) rows

  rowScale c i rows = updateAt i (map (* c)) rows

  add = zipWith (zipWith (+))

  scale c = map (map (* c))

  -- Entry (i, j) is the dot product of row i of `a` with column j of `b`.
  -- Transposing `b` uses the runtime column count `n` from the signature.
  matmul a b =
    let bCols = Data.Vect.transpose b
    in map (\aRow => map (\bCol => sum (zipWith (*) aRow bCol)) bCols) a
||| Multi-line rendering of a `RowMatrix`.
|||
||| Entries within a row are separated by single spaces. Consecutive rows are
||| separated by a newline plus a two-space indent, and the whole matrix is
||| wrapped in "[ " and " ]". A matrix with no rows or no columns renders as
||| "[]". This is a named implementation, so select it explicitly, e.g.
||| `show @{rowMatrix} x`.
implementation [rowMatrix] Show t => Show (RowMatrix m n t) where
  show [] = "[]"
  show ([] :: _) = "[]"
  show rows@((_ :: _) :: _) =
      -- Joining ALL rows with "\n  " avoids a stray newline before " ]"
      -- when the matrix has a single row.
      "[ " ++ joinWith "\n  " (map rowToStr rows) ++ " ]"
    where
      -- Join a non-empty vector of strings, placing `sep` between neighbours.
      joinWith : String -> Vect (S k) String -> String
      joinWith sep = foldr1 (\s1, s2 => s1 ++ sep ++ s2)

      -- Render one non-empty row as space-separated entries.
      rowToStr : Vect (S k) t -> String
      rowToStr = joinWith " " . map show
||| Demo: insert the row [7.0, 8.0, 9.0] at row position 1 of a 2x3 zero matrix
||| and print the resulting 3x3 matrix with the `rowMatrix` Show implementation.
main : IO ()
main =
  -- Explicit types pin the literals to `Fin 3` and `Vect 3 Double` instead of
  -- leaving them to be inferred from their later use.
  let i : Fin 3 = 1
      newRow : Vect 3 Double = [7.0, 8.0, 9.0]
      matrix : RowMatrix 2 3 Double = zerosRowMatrix 2 3
  in putStrLn $ show @{rowMatrix} $ insertRow {matrixType = RowMatrix} i newRow matrix
  --printLn {a = RowMatrix _ _} $ insertRow {matrixType=RowMatrix} i newRow matrix
{-
  Sign up for free to join this conversation on GitHub.
  Already have an account? Sign in to comment.
-}