@hyotang666
Created July 16, 2020 07:25
Naive implementation of dot function with numcl.
#| https://numpy.org/doc/stable/reference/generated/numpy.dot.html

Dot product of two arrays. Specifically,

 - If both a and b are 1-D arrays, it is inner product of vectors (without complex conjugation).
 - If both a and b are 2-D arrays, it is matrix multiplication, but using matmul or a @ b is preferred.
 - If either a or b is 0-D (scalar), it is equivalent to multiply and using numpy.multiply(a, b) or a * b is preferred.
 - If a is an N-D array and b is a 1-D array, it is a sum product over the last axis of a and b.
 - If a is an N-D array and b is an M-D array (where M>=2), it is a sum product over the last axis of a and the second-to-last axis of b:

       dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])
|#
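;;; Reference sketch (not part of the original gist): a plain Common Lisp
;;; transcription of the sum-product rule quoted above, possibly handy for
;;; cross-checking DOT on small inputs without going through einsum/reshape.
;;; NAIVE-DOT is a hypothetical name. It assumes ordinary CL arrays with
;;; A of rank >= 1 and B of rank >= 1, does no element-type specialization,
;;; and is deliberately slow (one APPLY #'AREF per term).
(defun naive-dot (a b)
  "Sum product over the last axis of A and the second-to-last axis of B,
or over the only axis of B when B is 1-D. Returns a number when the
result would be 0-D, otherwise a fresh array."
  (flet ((subscripts (index dims)
           ;; Row-major multi-index of INDEX in an array of dimensions DIMS.
           (let ((subs '()))
             (dolist (d (reverse dims) subs)
               (multiple-value-bind (q r) (floor index d)
                 (push r subs)
                 (setf index q))))))
    (let* ((dims-a (array-dimensions a))
           (dims-b (array-dimensions b))
           (b-vector-p (= 1 (length dims-b)))
           ;; Length of the contracted axis (the last axis of A).
           (n (car (last dims-a)))
           ;; Result shape: a.shape[:-1] + b.shape[:-2] + b.shape[-1:].
           (out-a (butlast dims-a))
           (out-b (if b-vector-p
                      '()
                      (append (butlast dims-b 2) (last dims-b))))
           (result (make-array (append out-a out-b) :initial-element 0)))
      ;; The contracted axes must have the same length.
      (assert (= n (if b-vector-p
                       (first dims-b)
                       (nth (- (length dims-b) 2) dims-b))))
      (dotimes (r (array-total-size result))
        (let* ((idx (subscripts r (array-dimensions result)))
               (ia (subseq idx 0 (length out-a)))   ; indices coming from A
               (ib (nthcdr (length out-a) idx)))    ; indices coming from B
          (setf (row-major-aref result r)
                (loop for k below n
                      sum (* (apply #'aref a (append ia (list k)))
                             (apply #'aref b
                                    (if b-vector-p
                                        (list k)
                                        ;; Put K in B's second-to-last slot.
                                        (append (butlast ib) (list k) (last ib)))))))))
      (if (zerop (array-rank result))
          (aref result)
          result))))

;; (naive-dot #(1 2 3) #(4 5 6)) => 32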
(defun dot (a b)
  (cond
    ;; If either a or b is 0-D (scalar),
    ;; it is equivalent to multiply and using numpy.multiply(a, b) or a * b is preferred.
    ((or (not (numcl:numcl-array-p a))
         (not (numcl:numcl-array-p b)))
     (numcl:* a b))
    ;; If both a and b are 1-D arrays, it is inner product of vectors (without complex conjugation).
    ((and (= 1 (array-rank a)) (= 1 (array-rank b)))
     (numcl:inner a b))
    ;; If both a and b are 2-D arrays, it is matrix multiplication, but using matmul or a @ b is preferred.
    ((and (= 2 (array-rank a)) (= 2 (array-rank b)))
     (numcl:matmul a b))
    ;; If a is an N-D array and b is a 1-D array, it is a sum product over the last axis of a and b.
    ;; Collapse the leading axes of a into one, contract, then restore them.
    ((and (<= 2 (array-rank a)) (= 1 (array-rank b)))
     (destructuring-bind (last-axis . rest) (reverse (array-dimensions a))
       (numcl:reshape (numcl:einsum '(ij j -> i)
                                    (numcl:reshape a (list (apply #'* rest) last-axis))
                                    b)
                      (reverse rest))))
    ;; If a is an N-D array and b is an M-D array (where M>=2),
    ;; it is a sum product over the last axis of a and the second-to-last axis of b:
    ((<= 2 (array-rank b))
     (destructuring-bind (last-axis-a . rest-a) (reverse (array-dimensions a))
       (destructuring-bind (last-axis-b last2-axis-b . rest-b) (reverse (array-dimensions b))
         (cond
           ;; Tensor * Tensor.
           ((and rest-a rest-b)
            (numcl:reshape (numcl:einsum '(ij kjl -> ikl)
                                         (numcl:reshape a (list (apply #'* rest-a) last-axis-a))
                                         (numcl:reshape b (list (apply #'* rest-b) last2-axis-b last-axis-b)))
                           (append (reverse rest-a) (reverse rest-b) (list last-axis-b))))
           ;; Vector * Matrix.
           ((and (null rest-a) (null rest-b))
            (numcl:einsum '(i ij -> j) a b))
           ;; Vector * Tensor.
           ((and (null rest-a) rest-b)
            (numcl:reshape (numcl:einsum '(i jik -> jk)
                                         a
                                         (numcl:reshape b (list (apply #'* rest-b)
                                                                last2-axis-b last-axis-b)))
                           (reverse (cons last-axis-b rest-b))))
           ;; Tensor * Matrix: the result keeps the leading axes of a
           ;; and ends with the last axis of b (not the last axis of a).
           ((and rest-a (null rest-b))
            (numcl:reshape (numcl:matmul (numcl:reshape a (list (apply #'* rest-a) last-axis-a))
                                         b)
                           (reverse (cons last-axis-b rest-a))))))))
    (t (error "DOT: not implemented yet for arrays of rank ~D and ~D."
              (array-rank a) (array-rank b)))))
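;;; Sketch (not part of the original gist): the numpy shape rule that the
;;; RESHAPE calls in DOT are meant to reproduce, written over plain
;;; dimension lists so it can be checked without numcl. DOT-RESULT-SHAPE
;;; is a hypothetical helper. It assumes both ranks are >= 1 and reuses the
;;; REVERSE + DESTRUCTURING-BIND idiom from DOT to peel off trailing axes.
(defun dot-result-shape (dims-a dims-b)
  "Dimensions of (dot a b) for arrays of dimensions DIMS-A and DIMS-B,
following numpy: a.shape[:-1] + b.shape[:-2] + b.shape[-1:]."
  (destructuring-bind (last-axis-a . rest-a) (reverse dims-a)
    (if (null (rest dims-b))
        ;; B is 1-D: its only axis is contracted away.
        (progn
          (assert (= last-axis-a (first dims-b)))
          (reverse rest-a))
        (destructuring-bind (last-axis-b last2-axis-b . rest-b) (reverse dims-b)
          ;; Last axis of A pairs with the second-to-last axis of B.
          (assert (= last-axis-a last2-axis-b))
          (append (reverse rest-a) (reverse rest-b) (list last-axis-b))))))

;; (dot-result-shape '(2 3 4) '(4 5)) => (2 3 5)
;; That is the shape the Tensor * Matrix branch of DOT must pass to
;; NUMCL:RESHAPE: a's leading axes followed by b's last axis.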
(assert (= 12 (dot 3 4)))
(assert (equalp (dot (numcl:asarray '((1 0) (0 1)))
                     (numcl:asarray '((4 1) (2 2))))
                (numcl:asarray '((4 1) (2 2)))))
;; The N-D example from the numpy page quoted at the top (also needs alexandria for IOTA).
(assert (= 499128 (numcl:aref (dot (numcl:reshape (numcl:arange (* 3 4 5 6)) '(3 4 5 6))
                                   (numcl:reshape (numcl:asarray (reverse (alexandria:iota (* 3 4 5 6))))
                                                  '(5 4 6 3)))
                              2 3 2 1 2 2)))
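;;; Where 499128 comes from (a worked check, not part of the original gist).
;;; The numpy page states that dot(a, b)[2,3,2,1,2,2] equals
;;; sum(a[2,3,2,:] * b[1,2,:,2]). In row-major layout, a (shape 3 4 5 6,
;;; strides 120 30 6 1) holds its flat index, so a[2,3,2,l] = 342 + l.
;;; b (shape 5 4 6 3, strides 72 18 3 1) holds 359 minus its flat index,
;;; so b[1,2,l,2] = 359 - (72 + 36 + 3l + 2) = 249 - 3l.
(assert (= 499128
           (loop for l below 6
                 sum (* (+ 342 l) (- 249 (* 3 l))))))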