Skip to content

Instantly share code, notes, and snippets.

@mclements
Last active July 22, 2020 10:54
Show Gist options
  • Save mclements/560be1c891044183d4b83dd2ad8fb5bd to your computer and use it in GitHub Desktop.
Save mclements/560be1c891044183d4b83dd2ad8fb5bd to your computer and use it in GitHub Desktop.
(* Hack to use MLton with BLAS by exposing the underlying data structure of Array2. *)
(* Fixity of the "bang" arithmetic operators used by SeqIndex and
   Array2ext.  They get the same precedence and left associativity as
   the ordinary +, - (level 6) and * (level 7). *)
infix 6 +! -!
infix 7 *!
(* SeqIndex: the index type used by the Array2ext code below.  Here it is
   simply Int, plus the "bang" operators (plain aliases of Int's +, -, *)
   and comparisons that treat Int values as unsigned. *)
structure SeqIndex = struct
   open Int
   val op +! = Int.+
   val op -! = Int.-
   val op *! = Int.*
   local
      (* Unsigned "less than" on ints: zero is the smallest value and
         every negative int orders above every non-negative one; two ints
         of the same sign compare as usual. *)
      fun unsignedLess (x, y) =
         if y = 0 then false
         else if x = 0 then true
         else if x < 0 then y < 0 andalso x < y
         else y < 0 orelse x < y
      structure Cmp = IntegralComparisons(type t = Int.int
                                          val (op <) = unsignedLess)
   in
      val ltu = Cmp.<
      val leu = Cmp.<=
      val gtu = Cmp.>
      val geu = Cmp.>=
   end
   (* Int <-> SeqIndex.int are the same type, so conversions are identity. *)
   fun toIntUnsafe i = i
   fun fromIntUnsafe i = i
end
(* Primitive: extends the Primitive structure that is in scope with the
   Array operations the Array2ext code expects (alloc, new, unsafeSub,
   unsafeUpdate, Slice), implemented via Unsafe.Array and ArraySlice. *)
structure Primitive = struct
   (* Bind the current Vector before `open Primitive`; presumably the
      opened structure has its own Vector that would otherwise shadow it. *)
   structure V = Vector
   open Primitive
   structure Array = struct
      open Array
      (* Unchecked element access (no bounds checks). *)
      val unsafeSub = Unsafe.Array.sub
      val unsafeUpdate = Unsafe.Array.update
      (* Allocation: `new` takes an initial value; `alloc` takes only a
         size and is used by Array2ext where every element is written
         before it is read. *)
      val alloc = Unsafe.Array.alloc
      val new = Unsafe.Array.create
      structure Slice = ArraySlice
   end
   (* Re-export the Vector bound before the open. *)
   structure Vector = V
end
(* ARRAY2EXT: the standard ARRAY2 signature plus two functions that
   convert between a 2-D array and its flat row-major backing array, so
   the data can be handed to foreign code (e.g. BLAS). *)
signature ARRAY2EXT = sig
   include ARRAY2
   (* makeArray (a, rows, cols): view the flat array a as a rows x cols
      2-D array. *)
   val makeArray : 'a Array.array * int * int -> 'a array
   (* getArray a: the flat backing array of a (in Array2ext this is the
      backing store itself, not a copy). *)
   val getArray : 'a array -> 'a Array.array
end
(* Array2ext: a copy of MLton's Array2 implementation that additionally
   exposes the flat backing array (getArray) and lets callers wrap an
   existing Array.array as a 2-D array (makeArray).  Elements are stored
   in row-major order: (r, c) lives at index r * cols + c (unsafeSpot'). *)
structure Array2ext : ARRAY2EXT = (* new name, new signature, 2 new functions *)
struct
   (* Index arithmetic and comparisons, taken from SeqIndex. *)
   val op +! = SeqIndex.+!
   val op + = SeqIndex.+
   val op -! = SeqIndex.-!
   val op - = SeqIndex.-
   val op *! = SeqIndex.*!
   val op * = SeqIndex.*
   val op < = SeqIndex.<
   val op <= = SeqIndex.<=
   val op > = SeqIndex.>
   val op >= = SeqIndex.>=
   val ltu = SeqIndex.ltu
   val leu = SeqIndex.leu
   val gtu = SeqIndex.gtu
   val geu = SeqIndex.geu

   (* A 2-D array: flat row-major storage plus its dimensions. *)
   type 'a array = {array: 'a Array.array,
                    rows: SeqIndex.int,
                    cols: SeqIndex.int}

   (* new functions *)

   (* getArray a returns the flat row-major backing array of a.  It is the
      same object, not a copy: updates through either view are visible
      through the other. *)
   fun 'a getArray ({array, ...} : 'a array) = array

   (* makeArray (a, m, n) views the flat array a as an m x n 2-D array in
      row-major order.  The array is shared, not copied.
      Raises Size (when Primitive.Controls.safe) if m or n is negative or
      if Array.length a <> m * n.  Without this check, sub/update only
      validate against m and n and then index `a` through the unchecked
      primitives, so a too-short array would be read/written out of bounds. *)
   fun 'a makeArray (a: 'a Array.array, m, n) : 'a array =
      let
         fun badShape () =
            m < 0 orelse n < 0
            orelse ((m * n) handle Overflow => raise Size) <> Array.length a
      in
         if Primitive.Controls.safe andalso badShape ()
            then raise Size
         else {array = a, rows = m, cols = n}
      end

   (* Dimension accessors; primed versions return SeqIndex.int. *)
   fun dimensions' ({rows, cols, ...}: 'a array) = (rows, cols)
   fun dimensions ({rows, cols, ...}: 'a array) =
      (SeqIndex.toIntUnsafe rows, SeqIndex.toIntUnsafe cols)
   fun nRows' ({rows, ...}: 'a array) = rows
   fun nRows ({rows, ...}: 'a array) = SeqIndex.toIntUnsafe rows
   fun nCols' ({cols, ...}: 'a array) = cols
   fun nCols ({cols, ...}: 'a array) = SeqIndex.toIntUnsafe cols

   type 'a region = {base: 'a array,
                     row: int,
                     col: int,
                     nrows: int option,
                     ncols: int option}

   local
      (* checkSliceMax' (start, num, max) validates the range starting at
         start of length num (or running to max when num = NONE) against
         max and returns (start, stop).  Raises Subscript when out of range
         (checked only when Primitive.Controls.safe). *)
      fun checkSliceMax' (start: int,
                          num: SeqIndex.int option,
                          max: SeqIndex.int): SeqIndex.int * SeqIndex.int =
         case num of
            NONE => if Primitive.Controls.safe
                       then let
                               val start =
                                  (SeqIndex.fromInt start)
                                  handle Overflow => raise Subscript
                            in
                               if gtu (start, max)
                                  then raise Subscript
                               else (start, max)
                            end
                    else (SeqIndex.fromIntUnsafe start, max)
          | SOME num => if Primitive.Controls.safe
                           then let
                                   val start =
                                      (SeqIndex.fromInt start)
                                      handle Overflow => raise Subscript
                                in
                                   if (start < 0 orelse num < 0
                                       orelse start +! num > max)
                                      then raise Subscript
                                   else (start, start +! num)
                                end
                        else (SeqIndex.fromIntUnsafe start,
                              SeqIndex.fromIntUnsafe start +! num)
      (* As checkSliceMax', but the length is an int option; overflow while
         converting it is reported as Subscript. *)
      fun checkSliceMax (start: int,
                         num: int option,
                         max: SeqIndex.int): SeqIndex.int * SeqIndex.int =
         if Primitive.Controls.safe
            then (checkSliceMax' (start, Option.map SeqIndex.fromInt num, max))
                 handle Overflow => raise Subscript
         else checkSliceMax' (start, Option.map SeqIndex.fromIntUnsafe num, max)
   in
      (* checkRegion' / checkRegion validate a region against its base array
         and return its start/stop rows and columns.  checkRegion' takes the
         lengths as SeqIndex.int, checkRegion as int. *)
      fun checkRegion' {base, row, col, nrows, ncols} =
         let
            val (rows, cols) = dimensions' base
            val (startRow, stopRow) = checkSliceMax' (row, nrows, rows)
            val (startCol, stopCol) = checkSliceMax' (col, ncols, cols)
         in
            {startRow = startRow, stopRow = stopRow,
             startCol = startCol, stopCol = stopCol}
         end
      fun checkRegion {base, row, col, nrows, ncols} =
         let
            val (rows, cols) = dimensions' base
            val (startRow, stopRow) = checkSliceMax (row, nrows, rows)
            val (startCol, stopCol) = checkSliceMax (col, ncols, cols)
         in
            {startRow = startRow, stopRow = stopRow,
             startCol = startCol, stopCol = stopCol}
         end
   end

   (* The region covering all of a. *)
   fun wholeRegion (a : 'a array): 'a region =
      {base = a, row = 0, col = 0, nrows = NONE, ncols = NONE}

   datatype traversal = RowMajor | ColMajor

   local
      (* Build a rows x cols array whose storage is produced by doit
         applied to the element count.  Raises Size on negative dimensions
         (when Primitive.Controls.safe) or if rows * cols overflows. *)
      fun make (rows, cols, doit) =
         if Primitive.Controls.safe
            andalso (rows < 0 orelse cols < 0)
            then raise Size
         else {array = doit (rows * cols handle Overflow => raise Size),
               rows = rows,
               cols = cols}
   in
      fun alloc' (rows, cols) =
         make (rows, cols, Primitive.Array.alloc)
      fun array' (rows, cols, init) =
         make (rows, cols, fn size => Primitive.Array.new (size, init))
   end

   local
      (* Convert int dimensions to SeqIndex.int (overflow -> Size). *)
      fun make (rows, cols, doit) =
         if Primitive.Controls.safe
            then let
                    val rows =
                       (SeqIndex.fromInt rows)
                       handle Overflow => raise Size
                    val cols =
                       (SeqIndex.fromInt cols)
                       handle Overflow => raise Size
                 in
                    doit (rows, cols)
                 end
         else doit (SeqIndex.fromIntUnsafe rows,
                    SeqIndex.fromIntUnsafe cols)
   in
      fun alloc (rows, cols) =
         make (rows, cols, fn (rows, cols) => alloc' (rows, cols))
      fun array (rows, cols, init) =
         make (rows, cols, fn (rows, cols) => array' (rows, cols, init))
   end

   (* The empty 0 x 0 array. *)
   fun array0 (): 'a array =
      {array = Primitive.Array.alloc 0,
       rows = 0,
       cols = 0}

   (* Flat index of (r, c): row-major layout. *)
   fun unsafeSpot' ({cols, ...}: 'a array, r, c) =
      r *! cols +! c
   (* Flat index of (r, c) with a bounds check (Subscript). *)
   fun spot' (a as {rows, cols, ...}: 'a array, r, c) =
      if Primitive.Controls.safe
         andalso (geu (r, rows) orelse geu (c, cols))
         then raise Subscript
      else unsafeSpot' (a, r, c)
   fun unsafeSub' (a as {array, ...}: 'a array, r, c) =
      Primitive.Array.unsafeSub (array, unsafeSpot' (a, r, c))
   fun sub' (a as {array, ...}: 'a array, r, c) =
      Primitive.Array.unsafeSub (array, spot' (a, r, c))
   fun unsafeUpdate' (a as {array, ...}: 'a array, r, c, x) =
      Primitive.Array.unsafeUpdate (array, unsafeSpot' (a, r, c), x)
   fun update' (a as {array, ...}: 'a array, r, c, x) =
      Primitive.Array.unsafeUpdate (array, spot' (a, r, c), x)

   local
      (* Convert int indices to SeqIndex.int (overflow -> Subscript). *)
      fun make (r, c, doit) =
         if Primitive.Controls.safe
            then let
                    val r =
                       (SeqIndex.fromInt r)
                       handle Overflow => raise Subscript
                    val c =
                       (SeqIndex.fromInt c)
                       handle Overflow => raise Subscript
                 in
                    doit (r, c)
                 end
         else doit (SeqIndex.fromIntUnsafe r,
                    SeqIndex.fromIntUnsafe c)
   in
      fun sub (a, r, c) =
         make (r, c, fn (r, c) => sub' (a, r, c))
      fun update (a, r, c, x) =
         make (r, c, fn (r, c) => update' (a, r, c, x))
   end

   (* fromList rows builds an array from a list of rows.  The column count
      is taken from the first row; raises Size if any row has a different
      length. *)
   fun 'a fromList (rows: 'a list list): 'a array =
      case rows of
         [] => array0 ()
       | row1 :: _ =>
            let
               val cols = length row1
               val a as {array, cols = cols', ...} =
                  alloc (length rows, cols)
               val _ =
                  List.foldl
                  (fn (row: 'a list, i) =>
                   let
                      (* i is the flat index of this row's first element. *)
                      val max = i +! cols'
                      val i' =
                         List.foldl (fn (x: 'a, i) =>
                                     (if i >= max
                                         then raise Size
                                      else (Primitive.Array.unsafeUpdate (array, i, x)
                                            ; i +! 1)))
                         i row
                   in if i' = max
                         then i'
                      else raise Size
                   end)
                  0 rows
            in
               a
            end

   (* row (a, r): a copy of row r as a vector (Subscript if out of range). *)
   fun row' ({array, rows, cols}, r) =
      if Primitive.Controls.safe andalso geu (r, rows)
         then raise Subscript
      else
         ArraySlice.vector (Primitive.Array.Slice.slice (array, r *! cols, SOME cols))
   fun row (a, r) =
      if Primitive.Controls.safe
         then let
                 val r =
                    (SeqIndex.fromInt r)
                    handle Overflow => raise Subscript
              in
                 row' (a, r)
              end
      else row' (a, SeqIndex.fromIntUnsafe r)

   (* column (a, c): a copy of column c as a vector (Subscript if out of
      range). *)
   fun column' (a as {rows, cols, ...}: 'a array, c) =
      if Primitive.Controls.safe andalso geu (c, cols)
         then raise Subscript
      else
         Primitive.Vector.tabulate (rows, fn r => unsafeSub' (a, r, c))
   fun column (a, c) =
      if Primitive.Controls.safe
         then let
                 val c =
                    (SeqIndex.fromInt c)
                    handle Overflow => raise Subscript
              in
                 column' (a, c)
              end
      else column' (a, SeqIndex.fromIntUnsafe c)

   (* foldi' trv f b region folds f over the elements of region, visiting
      them row by row (RowMajor) or column by column (ColMajor).  Indices
      are passed to f as SeqIndex.int. *)
   fun foldi' trv f b (region as {base, ...}) =
      let
         val {startRow, stopRow, startCol, stopCol} = checkRegion region
      in
         case trv of
            RowMajor =>
               let
                  fun loopRow (r, b) =
                     if r >= stopRow then b
                     else let
                             fun loopCol (c, b) =
                                if c >= stopCol then b
                                else loopCol (c +! 1, f (r, c, sub' (base, r, c), b))
                          in
                             loopRow (r +! 1, loopCol (startCol, b))
                          end
               in
                  loopRow (startRow, b)
               end
          | ColMajor =>
               let
                  fun loopCol (c, b) =
                     if c >= stopCol then b
                     else let
                             fun loopRow (r, b) =
                                if r >= stopRow then b
                                else loopRow (r +! 1, f (r, c, sub' (base, r, c), b))
                          in
                             loopCol (c +! 1, loopRow (startRow, b))
                          end
               in
                  loopCol (startCol, b)
               end
      end
   (* foldi with int indices. *)
   fun foldi trv f b a =
      foldi' trv (fn (r, c, x, b) =>
                  f (SeqIndex.toIntUnsafe r,
                     SeqIndex.toIntUnsafe c,
                     x, b)) b a
   fun fold trv f b a =
      foldi trv (fn (_, _, x, b) => f (x, b)) b (wholeRegion a)
   fun appi trv f =
      foldi trv (fn (r, c, x, ()) => f (r, c, x)) ()
   fun app trv f = fold trv (f o #1) ()
   fun modifyi trv f (region as {base, ...}) =
      appi trv (fn (r, c, x) => update (base, r, c, f (r, c, x))) region
   fun modify trv f a = modifyi trv (f o #3) (wholeRegion a)
   (* tabulate trv (rows, cols, f): element (r, c) is f (r, c); f is
      called in the order given by trv. *)
   fun tabulate trv (rows, cols, f) =
      let
         val a = alloc (rows, cols)
         val () = modifyi trv (fn (r, c, _) => f (r, c)) (wholeRegion a)
      in
         a
      end

   (* copy {src, dst, dst_row, dst_col} copies region src into dst with its
      top-left corner at (dst_row, dst_col).  Raises Subscript (when
      Primitive.Controls.safe) if either region is out of bounds.
      src and dst may be the same array with overlapping regions, so the
      loops run in a direction that reads every source cell before it can
      be overwritten:
        - rows go bottom-to-top when the destination starts at or below
          the source row, top-to-bottom otherwise;
        - columns go right-to-left when the destination starts at or to
          the right of the source column, left-to-right otherwise.
      (Column order only matters when both regions occupy the same rows.) *)
   fun copy {src = src as {base, ...}: 'a region,
             dst, dst_row, dst_col} =
      let
         val {startRow, stopRow, startCol, stopCol} = checkRegion src
         val nrows = stopRow -! startRow
         val ncols = stopCol -! startCol
         val {startRow = dst_row, startCol = dst_col, ...} =
            checkRegion' {base = dst, row = dst_row, col = dst_col,
                          nrows = SOME nrows,
                          ncols = SOME ncols}
         (* Apply f to start, start + 1, ..., stop - 1. *)
         fun forUp (start, stop, f: SeqIndex.int -> unit) =
            let
               fun loop i =
                  if i >= stop
                     then ()
                  else (f i; loop (i + 1))
            in loop start
            end
         (* Apply f to stop - 1, stop - 2, ..., start. *)
         fun forDown (start, stop, f: SeqIndex.int -> unit) =
            let
               fun loop i =
                  if i < start
                     then ()
                  else (f i; loop (i - 1))
            in loop (stop -! 1)
            end
         val forRows = if startRow <= dst_row then forDown else forUp
         (* Destination to the right of the source: go right-to-left, or
            an in-row shift would overwrite cells before they are read. *)
         val forCols = if startCol <= dst_col then forDown else forUp
      in forRows (0, nrows, fn r =>
         forCols (0, ncols, fn c =>
         unsafeUpdate' (dst, dst_row +! r, dst_col +! c,
                        unsafeSub' (base, startRow +! r, startCol +! c))))
      end
end
(* Matrix multiplication through CBLAS cblas_dgemm on the row-major
   backing arrays exposed by Array2ext. *)
local
   datatype cblasTranspose = NoTrans | Trans | ConjTrans | ConjNoTrans
   (* CBLAS_ORDER enumeration values. *)
   fun cblasOrder Array2ext.RowMajor = 101
     | cblasOrder Array2ext.ColMajor = 102
   (* CBLAS_TRANSPOSE enumeration values. *)
   fun cblasTranspose NoTrans = 111
     | cblasTranspose Trans = 112
     | cblasTranspose ConjTrans = 113
     | cblasTranspose ConjNoTrans = 114
   (* cblas_dgemm (order, transA, transB, M, N, K, alpha, A, lda, B, ldb,
                   beta, C, ldc) computes C := alpha * A * B + beta * C.
      A and B are passed as arrays (MLton hands C a pointer to the array
      data), so the operands are not copied into vectors first. *)
   val call = _import "cblas_dgemm" public:
                 int * int * int * int * int * int * real
                 * real Array.array * int * real Array.array * int
                 * real * real Array.array * int -> unit;
in
   (* matmul3 (a, b) returns the m x n product a * b, where a is m x k and
      b is k x n.  Neither argument is modified.
      Raises General.Size if the inner dimensions differ or if m * n
      overflows. *)
   fun matmul3(a : real Array2ext.array,
               b : real Array2ext.array) : real Array2ext.array =
      let
         open Array2ext
         val ((m,k), (k',n)) = (dimensions a, dimensions b)
         val () = if k <> k' then raise General.Size else ()
         val size = (m * n) handle Overflow => raise General.Size
         val arrayc = Array.array (size, 0.0)
      in
         (* BLAS requires every leading dimension to be at least 1, so a
            zero m, n or k must not reach cblas_dgemm.  The zero-filled
            result is already correct there: an m x 0 times 0 x n product
            is the m x n zero matrix. *)
         if m = 0 orelse n = 0 orelse k = 0
            then makeArray (arrayc, m, n)
         else (call (cblasOrder RowMajor,
                     cblasTranspose NoTrans, cblasTranspose NoTrans,
                     m, n, k, 1.0, getArray a, k, getArray b, n,
                     0.0, arrayc, n)
               ; makeArray (arrayc, m, n))
      end
end;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment