Skip to content

Instantly share code, notes, and snippets.

@akeep
Last active October 10, 2020 19:56
Show Gist options
  • Save akeep/a6e28aaf3f014b5d4089864ffa6252c3 to your computer and use it in GitHub Desktop.
Save akeep/a6e28aaf3f014b5d4089864ffa6252c3 to your computer and use it in GitHub Desktop.
Optimizing matrix multiply example, side trial: use a flat bytevector to represent the matrix
#|
usage:
> (import (my-matrix))
> (sanity)
#t
> (run-bench)
500 x 500 matrix multiply in Chez took 2472 msec
500 x 500 matrix multiply in Chez took 2474 msec
...
|#
(library (my-matrix (1))
(export run-bench sanity)
(import (chezscheme))
;;; reference: https://www.scheme.com/tspl3/examples.html
;;; A matrix record holds its dimensions plus one flat bytevector that
;;; stores each element as an 8-byte IEEE double.
;;;   (make-matrix rows columns)      => every element is 0.0
;;;   (make-matrix rows columns val)  => every element is val
(define-record-type matrix
  (nongenerative)
  (fields rows columns bv)
  (protocol
    (lambda (new)
      (case-lambda
        ;; make-bytevector without a fill argument leaves the contents
        ;; unspecified, and mat-mat-mul accumulates into a fresh matrix.
        ;; Fill with 0 bytes: eight zero bytes are the encoding of 0.0.
        [(rows columns)
         (new rows columns (make-bytevector (fx* rows columns 8) 0))]
        [(rows columns val)
         (let* ([bytes (fx* rows columns 8)]
                [bv (make-bytevector bytes)])
           ;; every 8-byte slot is overwritten below, so no fill is needed
           (do ([i 0 (fx+ i 8)])
               ((fx= i bytes))
             (bytevector-ieee-double-native-set! bv i val))
           (new rows columns bv))]))))
;;; matrix-ref returns the jth element of the ith row.
;;; Element (i, j) lives at element index i + j * rows of the backing
;;; bytevector; each element occupies 8 bytes.
(define (matrix-ref m i j)
  (let* ([column-start (fx* j (matrix-rows m))]
         [byte-offset (fx* 8 (fx+ column-start i))])
    (bytevector-ieee-double-native-ref (matrix-bv m) byte-offset)))
;;; matrix-set! changes the jth element of the ith row.
;;; Element (i, j) lives at element index i + j * rows of the backing
;;; bytevector; each element occupies 8 bytes.
(define (matrix-set! m i j x)
  (let* ([column-start (fx* j (matrix-rows m))]
         [byte-offset (fx* 8 (fx+ column-start i))])
    (bytevector-ieee-double-native-set! (matrix-bv m) byte-offset x)))
;;; mul is the generic matrix/scalar multiplication procedure.
;;;   number x number  -> their product (fl* when both are flonums, else *)
;;;   real x matrix,
;;;   matrix x real    -> a new matrix with every element scaled
;;;   matrix x matrix  -> a new matrix holding the matrix product
;;; Any other operand combination raises an exception whose "who" is mul.
(define mul
  (lambda (x y)
    ;; scalar->flonum converts a real scalar to a flonum so it can be
    ;; handed to fl*; exact integers and ratios are accepted, while
    ;; non-real numbers cannot be stored in a double and are rejected.
    (define scalar->flonum
      (lambda (s)
        (cond
          ((flonum? s) s)
          ((real? s) (inexact s))
          (else (errorf 'mul "~s is not a real scalar" s)))))
    ;; mat-sca-mul multiplies a matrix by a scalar.
    (define mat-sca-mul
      (lambda (m x)
        (let* ((nr (matrix-rows m))
               (nc (matrix-columns m))
               ;; coerce once, outside the loops
               (s (scalar->flonum x))
               ;; every element is overwritten below, so no fill is needed
               (r (make-matrix nr nc)))
          (do ((i 0 (fx+ i 1)))
              ((fx= i nr) r)
            (do ((j 0 (fx+ j 1)))
                ((fx= j nc))
              (matrix-set! r i j (fl* s (matrix-ref m i j))))))))
    ;; mat-mat-mul multiplies one matrix by another, after verifying
    ;; that the first matrix has as many columns as the second
    ;; matrix has rows.
    (define mat-mat-mul
      (lambda (m1 m2)
        (let ((nr1 (matrix-rows m1))
              (nr2 (matrix-rows m2))
              (nc2 (matrix-columns m2)))
          (unless (fx= (matrix-columns m1) nr2) (match-error m1 m2))
          ;; The result is accumulated in place, so it must start at 0.0;
          ;; ask for that explicitly instead of relying on the contents
          ;; of a freshly allocated bytevector.
          (let* ((r (make-matrix nr1 nc2 0.0))
                 (bv1 (matrix-bv m1))
                 (bv2 (matrix-bv m2))
                 (bvr (matrix-bv r)))
            ;; Element (i, j) is stored at element index i + j * rows, so
            ;; the innermost loop runs over i to touch bvr and bv1 at
            ;; consecutive offsets.  For each (i, j) the products are still
            ;; added in increasing k, so the sums are unchanged.
            (do ((j 0 (fx+ j 1)))
                ((fx= j nc2) r)
              (let ((rcol (fx* j nr1))
                    (bcol (fx* j nr2)))
                (do ((k 0 (fx+ k 1)))
                    ((fx= k nr2))
                  ;; m2[k, j] does not depend on i: read it once
                  (let ((b (bytevector-ieee-double-native-ref bv2 (fx* (fx+ k bcol) 8)))
                        (acol (fx* k nr1)))
                    (do ((i 0 (fx+ i 1)))
                        ((fx= i nr1))
                      (let ((rindex (fx* (fx+ i rcol) 8)))
                        (bytevector-ieee-double-native-set! bvr rindex
                          (fl+ (bytevector-ieee-double-native-ref bvr rindex)
                               (fl* (bytevector-ieee-double-native-ref bv1 (fx* (fx+ i acol) 8))
                                    b)))))))))))))
    ;; type-error is called to complain when mul receives an invalid
    ;; type of argument.  errorf (not error) is used because the message
    ;; contains format directives.
    (define type-error
      (lambda (what)
        (errorf 'mul
                "~s is not a number or matrix"
                what)))
    ;; match-error is called to complain when mul receives a pair of
    ;; incompatible arguments.
    (define match-error
      (lambda (what1 what2)
        (errorf 'mul
                "~s and ~s are incompatible operands"
                what1
                what2)))
    ;; body of mul; dispatch based on input types
    (cond
      ((flonum? x)
       (cond
         ((flonum? y) (fl* x y))
         ;; a flonum times any other number is still a valid product
         ((number? y) (* x y))
         ((matrix? y) (mat-sca-mul y x))
         (else (type-error y))))
      ((number? x)
       (cond
         ((number? y) (* x y))
         ((matrix? y) (mat-sca-mul y x))
         (else (type-error y))))
      ((matrix? x)
       (cond
         ((number? y) (mat-sca-mul x y))
         ((matrix? y) (mat-mat-mul x y))
         (else (type-error y))))
      (else (type-error x)))))
;;; fill-random overwrites every element of m with a flonum of the form
;;; p / (2.0 + q), where p and q each come from a call to (random 100).
(define fill-random
  (lambda (m)
    (let ([row-count (matrix-rows m)]
          [col-count (matrix-columns m)])
      (let next-row ([i 0])
        (unless (fx= i row-count)
          (let next-col ([j 0])
            (unless (fx= j col-count)
              (matrix-set! m i j
                (fl/ (inexact (random 100))
                     (fl+ 2.0 (inexact (random 100)))))
              (next-col (fx+ j 1))))
          (next-row (fx+ i 1)))))))
;;; bench multiplies a by b with mul and returns the result.
(define bench
  (lambda (a b)
    (mul a b)))
;;; run-bench times mul on freshly randomized square matrices and prints
;;; one line per multiplication.  It performs 10 runs; within a run the
;;; size starts at 500 and steps by 100 while it stays below 600.
(define run-bench
  (lambda ()
    (collect)
    (let next-run ([run 0])
      (when (fx< run 10)
        (let next-size ([sz 500])
          (when (fx< sz 600)
            (let ([a (make-matrix sz sz)]
                  [b (make-matrix sz sz)])
              (fill-random a)
              (fill-random b)
              ;; let* forces clock / multiply / clock to happen in order;
              ;; the product itself is discarded
              (let* ([started (real-time)]
                     [product (mul a b)]
                     [finished (real-time)])
                (format #t "~s x ~s matrix multiply in Chez took ~s msec~%" sz sz (- finished started))))
            (next-size (fx+ sz 100))))
        (next-run (fx+ 1 run))))))
;;; fill-sequential stores start, start + 1.0, start + 2.0, ... into m,
;;; walking row by row (all of row 0 first, then row 1, and so on).
(define fill-sequential
  (lambda (m start)
    (let ([row-count (matrix-rows m)]
          [col-count (matrix-columns m)]
          [next start])
      (do ([i 0 (fx+ i 1)])
          ((fx= i row-count))
        (do ([j 0 (fx+ j 1)])
            ((fx= j col-count))
          (matrix-set! m i j next)
          ;; running value carries over from one row to the next
          (set! next (fl+ next 1.0)))))))
;;; matrix-same returns #t when m1 and m2 have the same dimensions and
;;; every pair of corresponding elements compares equal under fl=;
;;; otherwise it returns #f.
(define matrix-same
  (lambda (m1 m2)
    (let ([rows1 (matrix-rows m1)]
          [rows2 (matrix-rows m2)]
          [cols1 (matrix-columns m1)]
          [cols2 (matrix-columns m2)])
      (and (eqv? cols1 cols2)
           (eqv? rows1 rows2)
           ;; single cursor over (i, j), row by row; stops at the first
           ;; mismatch
           (let scan ([i 0] [j 0])
             (cond
               [(fx= i rows1) #t]
               [(fx= j cols1) (scan (fx+ i 1) 0)]
               [(fl= (matrix-ref m1 i j) (matrix-ref m2 i j))
                (scan i (fx+ j 1))]
               [else #f]))))))
;; sanity test, is our logic right?
;; expect that #(#(1 2) #(3 4)) x #(#(5 6) #(7 8)) == #(#(19 22) #(43 50))
;; returns the result of comparing the computed product with that
;; expected matrix via matrix-same.
(define sanity
  (lambda ()
    (let ([lhs (make-matrix 2 2)]
          [rhs (make-matrix 2 2)]
          [want (make-matrix 2 2)])
      ;; operands: 1..4 and 5..8, filled row by row
      (fill-sequential lhs 1.0)
      (fill-sequential rhs 5.0)
      ;; hand-computed product
      (matrix-set! want 0 0 19.0)
      (matrix-set! want 0 1 22.0)
      (matrix-set! want 1 0 43.0)
      (matrix-set! want 1 1 50.0)
      (matrix-same (mul lhs rhs) want))))
#|
chez scheme timings, on mac book pro:
(with (optimize-level 3); as we go ~ 10% faster)
scheme --optimize-level 3 ./matrix.ss
Chez Scheme Version 9.5.1
Copyright 1984-2017 Cisco Systems, Inc.
;; top-level bindings, without a library wrapper:
500 x 500 matrix multiply in Chez took 2606 msec
500 x 500 matrix multiply in Chez took 2605 msec
500 x 500 matrix multiply in Chez took 2571 msec
500 x 500 matrix multiply in Chez took 2634 msec
500 x 500 matrix multiply in Chez took 2597 msec
500 x 500 matrix multiply in Chez took 2603 msec
500 x 500 matrix multiply in Chez took 2565 msec
500 x 500 matrix multiply in Chez took 2535 msec
500 x 500 matrix multiply in Chez took 2587 msec
500 x 500 matrix multiply in Chez took 2547 msec
;; inside the my-matrix library:
500 x 500 matrix multiply in Chez took 2435 msec
500 x 500 matrix multiply in Chez took 2498 msec
500 x 500 matrix multiply in Chez took 2486 msec
500 x 500 matrix multiply in Chez took 2496 msec
500 x 500 matrix multiply in Chez took 2465 msec
500 x 500 matrix multiply in Chez took 2499 msec
500 x 500 matrix multiply in Chez took 2492 msec
500 x 500 matrix multiply in Chez took 2532 msec
500 x 500 matrix multiply in Chez took 2463 msec
500 x 500 matrix multiply in Chez took 2526 msec
;; update! shifting to (fl*) instead of (*) shaved
;; off 17%
500 x 500 matrix multiply in Chez took 2075 msec
500 x 500 matrix multiply in Chez took 2040 msec
500 x 500 matrix multiply in Chez took 2054 msec
500 x 500 matrix multiply in Chez took 2059 msec
500 x 500 matrix multiply in Chez took 2066 msec
500 x 500 matrix multiply in Chez took 2048 msec
500 x 500 matrix multiply in Chez took 2053 msec
500 x 500 matrix multiply in Chez took 2112 msec
500 x 500 matrix multiply in Chez took 2064 msec
500 x 500 matrix multiply in Chez took 2060 msec
|#
) ;; end my-matrix library
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment