Created
July 16, 2020 07:25
-
-
Save hyotang666/4955cc17ea9c7b052542456396eb5649 to your computer and use it in GitHub Desktop.
Naive implementation of dot function with numcl.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#| https://numpy.org/doc/stable/reference/generated/numpy.dot.html
Dot product of two arrays. Specifically:
- If both a and b are 1-D arrays, it is inner product of vectors (without complex conjugation).
- If both a and b are 2-D arrays, it is matrix multiplication, but using matmul or a @ b is preferred.
- If either a or b is 0-D (scalar), it is equivalent to multiply and using numpy.multiply(a, b) or a * b is preferred.
- If a is an N-D array and b is a 1-D array, it is a sum product over the last axis of a and b.
- If a is an N-D array and b is an M-D array (where M>=2), it is a sum product over the last axis of a and the second-to-last axis of b:
    dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])
|#
(defun dot (a b)
  "Return the dot product of A and B, following the cases of numpy.dot.

Dispatch, in order:
  * A or B is not a numcl array (e.g. a number): (numcl:* A B).
  * 1-D . 1-D: inner product of vectors (numcl:inner).
  * 2-D . 2-D: matrix product (numcl:matmul).
  * N-D . 1-D with N >= 2: sum product over the last axis of A and B.
    The result has A's dimensions without its last axis.
  * N-D . M-D with N >= 1 and M >= 2: sum product over the last axis of A
    and the second-to-last axis of B.  The result dimensions are A's
    without its last axis, then B's without its last two axes, then B's
    last axis.
Any other combination of numcl arrays (a rank-0 array, if one is passed)
signals the error \"NIY\"."
  (cond
    ;; A non-numcl operand is treated as a scalar: plain multiplication.
    ((or (not (numcl:numcl-array-p a))
         (not (numcl:numcl-array-p b)))
     (numcl:* a b))
    ;; 1-D . 1-D: inner product of vectors (without complex conjugation).
    ((and (= 1 (array-rank a)) (= 1 (array-rank b)))
     (numcl:inner a b))
    ;; 2-D . 2-D: matrix multiplication.
    ((and (= 2 (array-rank a)) (= 2 (array-rank b)))
     (numcl:matmul a b))
    ;; N-D . 1-D: flatten A's leading axes into one, contract the last axis
    ;; against B, then restore the leading axes.
    ((and (<= 2 (array-rank a)) (= 1 (array-rank b)))
     (destructuring-bind (last-axis . rest) (reverse (array-dimensions a))
       (numcl:reshape (numcl:einsum '(ij j -> i)
                                    (numcl:reshape a (list (apply #'* rest) last-axis))
                                    b)
                      (reverse rest))))
    ;; N-D . M-D (M >= 2): sum product over the last axis of A and the
    ;; second-to-last axis of B.  The rank guard on A keeps rank-0 arrays
    ;; out of the DESTRUCTURING-BIND below (it needs at least one dimension)
    ;; so they reach the NIY error like every other rank-0 case.
    ((and (<= 1 (array-rank a)) (<= 2 (array-rank b)))
     (destructuring-bind (last-axis-a . rest-a) (reverse (array-dimensions a))
       (destructuring-bind (last-axis-b last2-axis-b . rest-b)
           (reverse (array-dimensions b))
         ;; REST-A / REST-B are the leading axes of A / B, in reverse order.
         (cond
           ;; Tensor . Tensor (A rank >= 2, B rank >= 3).
           ((and rest-a rest-b)
            (numcl:reshape (numcl:einsum '(ij kjl -> ikl)
                                         (numcl:reshape a (list (apply #'* rest-a) last-axis-a))
                                         (numcl:reshape b (list (apply #'* rest-b) last2-axis-b last-axis-b)))
                           (append (reverse rest-a) (reverse rest-b) (list last-axis-b))))
           ;; Vector . Matrix.
           ((and (null rest-a) (null rest-b))
            (numcl:einsum '(i ij -> j) a b))
           ;; Vector . Tensor.
           ((null rest-a)
            (numcl:reshape (numcl:einsum '(i jik -> jk)
                                         a
                                         (numcl:reshape b (list (apply #'* rest-b) last2-axis-b last-axis-b)))
                           (append (reverse rest-b) (list last-axis-b))))
           ;; Tensor . Matrix (A rank >= 3, B rank 2).  The product's last
           ;; axis is B's column axis, so the result ends in LAST-AXIS-B
           ;; (not LAST-AXIS-A, which is contracted away).
           (t
            (numcl:reshape (numcl:matmul (numcl:reshape a (list (apply #'* rest-a) last-axis-a))
                                         b)
                           (append (reverse rest-a) (list last-axis-b))))))))
    (t (error "NIY"))))
;; Two plain numbers fall through to ordinary multiplication.
(assert (= (dot 3 4) 12))

;; Left-multiplying by the 2x2 identity leaves the right operand unchanged.
(let ((identity (numcl:asarray '((1 0) (0 1))))
      (rhs (numcl:asarray '((4 1) (2 2)))))
  (assert (equalp (dot identity rhs)
                  (numcl:asarray '((4 1) (2 2))))))

;; 4-D . 4-D: the result has dimensions (3 4 5 5 4 3); spot-check one element.
(let* ((count (* 3 4 5 6))
       (lhs (numcl:reshape (numcl:arange count) '(3 4 5 6)))
       (rhs (numcl:reshape (numcl:asarray (reverse (alexandria:iota count)))
                           '(5 4 6 3)))
       (product (dot lhs rhs)))
  (assert (= (numcl:aref product 2 3 2 1 2 2) 499128)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment