Skip to content

Instantly share code, notes, and snippets.

@gwerbin
Last active May 4, 2021 06:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gwerbin/40309efeb3324677f2faa12fb76bd2a2 to your computer and use it in GitHub Desktop.
Save gwerbin/40309efeb3324677f2faa12fb76bd2a2 to your computer and use it in GitHub Desktop.
Sketch of a matrix data type in Idris
module Matrix
import Data.Vect
%default total
---- Interface ----
||| The kind of a matrix representation: given a row count, a column count
||| and an entry type, it produces the type of matrices of that shape.
AbstractMatrix : Type
AbstractMatrix = Nat -> Nat -> Type -> Type
||| Operations on m-by-n matrices with entries of type `t`.
|||
||| `matrixType` is the concrete representation (an `AbstractMatrix`). It has
||| multiplicity 0 (erased at runtime), and `| matrixType` marks it as the
||| determining parameter used for implementation search.
interface Matrix (0 matrixType : AbstractMatrix) | matrixType where
  {- Size and shape -}
  -- Implementations should prove correctness:
  --   nrows x = m
  --   ncols x = n
  --   size x  = m * n

  ||| Number of rows
  nrows : {m : _} -> {n : _} -> matrixType m n t -> Nat
  nrows _ = m

  ||| Number of columns
  ncols : {m : _} -> {n : _} -> matrixType m n t -> Nat
  ncols _ = n

  ||| Number of entries
  size : {m : _} -> {n : _} -> matrixType m n t -> Nat
  size _ = m * n

  ||| Shape as (number of rows, number of columns)
  shape : {m : _} -> {n : _} -> matrixType m n t -> (Nat, Nat)
  shape _ = (m, n)

  {- Accessing parts -}
  -- TODO: blocks, sub-matrices, strides, diagonals

  ||| Get one row
  row : (i : Fin m) -> matrixType m n t -> Vect n t

  ||| Get one column
  col : (j : Fin n) -> matrixType m n t -> Vect m t

  ||| Get the entry in row `i`, column `j`
  index : (i : Fin m) -> (j : Fin n) -> matrixType m n t -> t

  {- Constructing -}
  -- Implementations should prove correctness:
  --   insertRow FZ   = prependRow
  --   insertCol FZ   = prependCol
  --   insertRow last = appendRow
  --   insertCol last = appendCol

  ||| Insert a row at position `i`, shifting the rows from `i` onward down by 1
  insertRow : {m : _} ->
              {n : _} ->
              (i : Fin (S m)) ->
              Vect n t ->
              matrixType m n t ->
              matrixType (S m) n t

  ||| Insert a column at position `j`, shifting the columns from `j` onward
  ||| right by 1
  insertCol : {m : _} ->
              {n : _} ->
              (j : Fin (S n)) ->
              Vect m t ->
              matrixType m n t ->
              matrixType m (S n) t

  ||| Prepend a row to the top of a matrix
  prependRow : {m : _} ->
               {n : _} ->
               Vect n t ->
               matrixType m n t ->
               matrixType (S m) n t
  prependRow = insertRow FZ

  ||| Prepend a column to the left side of a matrix
  prependCol : {m : _} ->
               {n : _} ->
               Vect m t ->
               matrixType m n t ->
               matrixType m (S n) t
  prependCol = insertCol FZ

  ||| Append a row to the bottom of a matrix
  appendRow : {m : _} ->
              {n : _} ->
              Vect n t ->
              matrixType m n t ->
              matrixType (S m) n t
  appendRow = insertRow last

  ||| Append a column to the right side of a matrix
  appendCol : {m : _} ->
              {n : _} ->
              Vect m t ->
              matrixType m n t ->
              matrixType m (S n) t
  appendCol = insertCol last

  ||| Concatenate matrices vertically: the rows of the second matrix are
  ||| placed below the rows of the first
  concatRows : matrixType m1 n t ->
               matrixType m2 n t ->
               matrixType (m1 + m2) n t

  ||| Concatenate matrices horizontally: the columns of the second matrix are
  ||| placed to the right of the columns of the first
  concatCols : matrixType m n1 t ->
               matrixType m n2 t ->
               matrixType m (n1 + n2) t

  {- Operations -}
  -- Implementations should prove correctness:
  --   transpose . transpose = id
  --   (rowSwap i j) . (rowSwap j i) = id

  ||| Matrix transposition
  transpose : {m : _} -> {n : _} -> matrixType m n t -> matrixType n m t

  ||| Swap two rows
  rowSwap : (rownum1 : Fin m) -> (rownum2 : Fin m) -> matrixType m n t -> matrixType m n t

  ||| Replace row `rownum1` with the elementwise sum of rows `rownum1` and `rownum2`
  rowAdd : Num t => (rownum1 : Fin m) -> (rownum2 : Fin m) -> matrixType m n t -> matrixType m n t

  ||| Replace row `rownum` with itself multiplied by `scalar`
  rowScale : Num t => (scalar : t) -> (rownum : Fin m) -> matrixType m n t -> matrixType m n t

  ||| Matrix addition: A + B
  add : Num t => matrixType m n t -> matrixType m n t -> matrixType m n t

  -- ||| Matrix subtraction: A - B, i.e. A + (-1) B
  -- subtract : Neg t => matrixType m n t -> matrixType m n t -> matrixType m n t

  ||| Scalar multiplication: c A
  scale : Num t => t -> matrixType m n t -> matrixType m n t

  -- ||| Scalar division: (1/c) A
  -- scaleDiv : Fractional t => t -> matrixType m n t -> matrixType m n t

  ||| Matrix multiplication: A B, with A of shape m x k and B of shape k x n.
  ||| `n` is bound at runtime because when k = 0 the result is still an
  ||| m x n matrix (of zeros) that an implementation has to build.
  matmul : Num t => {n : _} -> matrixType m k t -> matrixType k n t -> matrixType m n t
---- Implementation: Row-oriented vector-of-vectors ----
{- Type -}
||| Row-major representation: a vector of `m` rows, each a vector of `n`
||| entries.
RowMatrix : (m : Nat) -> (n : Nat) -> (t : Type) -> Type
RowMatrix rows cols entry = Vect rows (Vect cols entry)
{- Construction -}
||| Build an `nrow` x `ncol` matrix in which every entry is `fillval`.
|||
||| Empty shapes need no special handling: `nrow = 0` yields `[]`, and
||| `ncol = 0` yields `nrow` empty rows, because `replicate 0 _ = []`.
filledRowMatrix : (fillval : filltype) -> (nrow : Nat) -> (ncol : Nat) -> RowMatrix nrow ncol filltype
filledRowMatrix fillval nrow ncol = Data.Vect.replicate nrow (Data.Vect.replicate ncol fillval)
||| An `nrow` x `ncol` matrix of `Double` zeros.
zerosRowMatrix : (nrow : Nat) -> (ncol : Nat) -> RowMatrix nrow ncol Double
zerosRowMatrix nrow ncol = filledRowMatrix 0.0 nrow ncol
||| An `nrow` x `ncol` matrix of `Double` ones.
onesRowMatrix : (nrow : Nat) -> (ncol : Nat) -> RowMatrix nrow ncol Double
onesRowMatrix nrow ncol = filledRowMatrix 1.0 nrow ncol
{- Core Matrix interface -}
-- Row-major implementation of `Matrix`: a matrix is a `Vect` of its rows.
-- Names from Data.Vect that clash with method names (index, transpose) are
-- written fully qualified.
implementation Matrix RowMatrix where
  {- Size and shape -}
  -- nrows, ncols, size and shape use the interface's default definitions.

  {- Accessing parts -}
  row = Data.Vect.index
  -- A column is the j-th entry taken from every row.
  col j = map (Data.Vect.index j)
  index i j = Data.Vect.index j . Data.Vect.index i

  {- Constructing -}
  insertRow = insertAt
  -- Inserting a column means inserting one entry into each row.
  insertCol j = zipWith (insertAt j)
  prependRow = Data.Vect.(::)
  prependCol = zipWith (::)
  appendRow = flip snoc
  appendCol newCol rows = zipWith snoc rows newCol
  concatRows = Data.Vect.(++)
  concatCols = zipWith (++)

  {- Operations -}
  transpose = Data.Vect.transpose

  -- Both rows are read from the ORIGINAL matrix before either is written.
  -- Reading row i1 after overwriting it would copy row i2 into both slots.
  -- Also correct when i1 == i2 (the matrix is returned unchanged).
  rowSwap i1 i2 rows =
    let r1 = Data.Vect.index i1 rows
        r2 = Data.Vect.index i2 rows
    in replaceAt i2 r1 (replaceAt i1 r2 rows)

  -- Row i1 becomes (row i1 + row i2), elementwise; row i2 is read from the
  -- original matrix.
  rowAdd i1 i2 rows =
    updateAt i1 (\r1 => zipWith (+) r1 (Data.Vect.index i2 rows)) rows

  rowScale c i rows = updateAt i (map (* c)) rows

  -- Elementwise sum: zip the rows, then zip the entries within each row.
  add = zipWith (zipWith (+))

  scale c = map (map (* c))

  -- TODO: matmul is not implemented for RowMatrix yet.
  -- TODO: Vect is a cons list, so index/replaceAt/updateAt above walk the
  -- rows linearly; a different backing store would be needed for speed.
-- Named `Show` implementation that renders a `RowMatrix` one row per line.
--
-- The first row follows an opening "[ ". Every later row starts on a new
-- line, indented by the width of "[ " so the entries line up, and " ]" closes
-- the last row. Entries within a row are separated by single spaces.
-- A matrix with no rows, or whose rows are empty, renders as "[]".
--
-- Select it explicitly with `show @{rowMatrix}`: `RowMatrix m n t` is just
-- `Vect m (Vect n t)`, which already has an anonymous `Show` implementation.
implementation [rowMatrix] Show t => Show (RowMatrix m n t) where
  show [] = "[]"
  show ([] :: _) = "[]"
  show (row1@(_ :: _) :: rows) =
      "[ " ++ strJoin "\n" (rowToStr row1 :: map (indent . rowToStr) rows) ++ " ]"
    where
      -- Join strings with a separator between consecutive items only, so a
      -- single item gets no separator at all.
      strJoin : String -> Vect k String -> String
      strJoin _ [] = ""
      strJoin sep (s :: ss) = foldl (\acc, x => acc ++ sep ++ x) s ss

      rowToStr : Show a => Vect k a -> String
      rowToStr = strJoin " " . map show

      -- Two spaces: the width of the opening "[ ".
      indent : String -> String
      indent = ("  " ++)
||| Demo: insert the row [7.0, 8.0, 9.0] at row index 1 of a 2x3 zero matrix
||| and print the resulting 3x3 matrix with the named `rowMatrix` Show
||| implementation.
main : IO ()
main = do
  let base = zerosRowMatrix 2 3
      extended = insertRow {matrixType = RowMatrix} 1 [7.0, 8.0, 9.0] base
  putStrLn (show @{rowMatrix} extended)
  -- Alternative (not used): printLn {a = RowMatrix _ _} extended
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment