Skip to content

Instantly share code, notes, and snippets.

@mclements
Created July 21, 2020 16:11
Show Gist options
  • Save mclements/8641f8c3376d5ea3f0da3464541bb11f to your computer and use it in GitHub Desktop.
Save mclements/8641f8c3376d5ea3f0da3464541bb11f to your computer and use it in GitHub Desktop.
SML / MLton: naive implementation for BLAS-based matrix multiplication
local
(* Foreign binding to the C routine cblas_dgemm, obtained through MLton's
 * _import mechanism.  The argument roles noted below follow the order in
 * which matmul2 supplies them; the short names are those of the CBLAS C
 * prototype.  The last array argument is the only mutable one: it is the
 * buffer the routine writes the result into.
 *)
val call =
  _import "cblas_dgemm" public:
      int                   (* storage order code, see cblasOrder *)
    * int                   (* transpose code for the first operand *)
    * int                   (* transpose code for the second operand *)
    * int                   (* m *)
    * int                   (* n *)
    * int                   (* k *)
    * real                  (* alpha: scalar applied to the product *)
    * real Vector.vector    (* first operand, flattened *)
    * int                   (* lda: leading dimension of the first operand *)
    * real Vector.vector    (* second operand, flattened *)
    * int                   (* ldb: leading dimension of the second operand *)
    * real                  (* beta: scalar applied to the output buffer *)
    * real Array.array      (* output buffer *)
    * int                   (* ldc: leading dimension of the output buffer *)
    -> unit;
(* Symbolic names for the transposition option given to cblas_dgemm for each
 * operand.  The function cblasTranspose (same identifier, value namespace)
 * converts a constructor to its integer code.
 *)
datatype cblasTranspose =
    NoTrans
  | Trans
  | ConjTrans
  | ConjNoTrans
(* cblasOrder order: integer code handed to cblas_dgemm for an Array2
 * traversal order: 101 for RowMajor, 102 for ColMajor.
 *)
fun cblasOrder order =
    case order of
        Array2.RowMajor => 101
      | Array2.ColMajor => 102
(* cblasTranspose mode: integer code handed to cblas_dgemm for a
 * transposition option.  The codes run consecutively from 111 (NoTrans)
 * to 114 (ConjNoTrans).
 *)
fun cblasTranspose mode =
    case mode of
        NoTrans => 111
      | Trans => 112
      | ConjTrans => 113
      | ConjNoTrans => 114
(* getVector a: flatten the two-dimensional real array a into a vector laid
 * out column by column.  For an array with `rows` rows, element (i, j) is
 * stored at position i + rows * j of the result.
 *)
fun getVector (a : real Array2.array) : real Vector.vector =
    let
      val (rows, cols) = Array2.dimensions a
      (* Inverse of p = i + rows * j.  When rows = 0 the result is empty, so
         this function is never applied and no division by zero occurs. *)
      fun elementAt p = Array2.sub (a, p mod rows, p div rows)
    in
      Vector.tabulate (rows * cols, elementAt)
    end
(* makeArray (a, m, n): build an m-by-n Array2 array from the flat array a,
 * reading a as column-by-column storage: entry (row, col) of the result is
 * taken from position row + m * col of a.
 *)
fun makeArray (a, m, n) =
    let
      fun entry (row, col) = Array.sub (a, col * m + row)
    in
      Array2.tabulate Array2.RowMajor (m, n, entry)
    end
in
(* matmul2 (a, b): matrix product of two real Array2 arrays.
 *
 * a has dimensions m-by-k and b has dimensions k'-by-n; the result is a
 * freshly allocated m-by-n Array2 array.  Both operands are copied into
 * column-by-column vectors (getVector), multiplied by cblas_dgemm with
 * alpha = 1.0 and beta = 0.0 and no transposition, and the flat output is
 * converted back with makeArray.
 *
 * Raises Size when the inner dimensions differ (k <> k').
 *
 * When any of m, n, k is zero the BLAS routine is not called: the leading
 * dimensions passed below (m for a and the result, k for b) would be 0,
 * which BLAS rejects since each must be at least 1, and the zero-filled
 * output buffer is already the correct (empty or all-zero) product.
 *)
fun matmul2(a, b) =
    let
      val (m, k) = Array2.dimensions a
      val (k', n) = Array2.dimensions b
      val () = if k <> k' then raise General.Size else ()
      (* Flat output buffer, zero-initialised, with leading dimension m. *)
      val arrayc = Array.array (m * n, 0.0)
      val () =
          if m = 0 orelse n = 0 orelse k = 0
          then ()
          else call (cblasOrder Array2.ColMajor,
                     cblasTranspose NoTrans,
                     cblasTranspose NoTrans,
                     m, n, k,
                     1.0, getVector a, m,
                     getVector b, k,
                     0.0, arrayc, m)
    in
      makeArray (arrayc, m, n)
    end
end;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment