Created
September 11, 2020 08:28
-
-
Save mariogeiger/47d53d1db0141db316a3d48c3b8269a8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import functools | |
def einsum(operand, *tensors):
    """Naively evaluate an Einstein summation, as a reference for ``np.einsum``.

    Every output element is computed by an explicit Python loop over all
    combinations of the summed indices, so the cost is the product of the sizes
    of *all* indices (output and summed) times the number of operands.  This is
    meant for checking correctness on small inputs, not for speed.

    Args:
        operand: Subscripts such as ``'ij,jk->ik'``.  Each comma-separated
            operand uses one letter per axis; a letter repeated inside one
            operand selects a diagonal (e.g. ``'ii->'`` is the trace).  If
            ``'->'`` is omitted, the output letters are those appearing exactly
            once, in alphabetical order (NumPy's implicit mode).
        *tensors: One array-like per comma-separated operand.

    Returns:
        A numpy array with one axis per output letter, in output order
        (0-d for a full contraction).

    Raises:
        ValueError: if the number of tensors or an operand's rank does not
            match the subscripts, if one letter is used with different sizes,
            or if an output letter is repeated or absent from the inputs.
    """
    import itertools

    tensors = [np.asarray(x) for x in tensors]

    if '->' in operand:
        in_spec, out_spec = operand.split('->')
    else:
        # Implicit mode: keep letters that occur once, sorted alphabetically.
        letters = in_spec = operand
        letters = letters.replace(',', '')
        out_spec = ''.join(sorted(i for i in set(letters) if letters.count(i) == 1))
    op_in = in_spec.split(',')

    # zip() below would silently drop extra operands/tensors, so check first.
    if len(op_in) != len(tensors):
        raise ValueError(
            f'subscripts describe {len(op_in)} operand(s) '
            f'but {len(tensors)} tensor(s) were given'
        )

    # Size of every index letter, checked for consistency across operands.
    # Insertion order = first appearance, which keeps the summation order
    # deterministic (a set's order depends on string hash randomization).
    dx = {}
    for op, x in zip(op_in, tensors):
        if len(op) != x.ndim:
            raise ValueError(
                f'operand {op!r} has {len(op)} index(es) but tensor has '
                f'{x.ndim} dimension(s)'
            )
        for i, n in zip(op, x.shape):
            if dx.setdefault(i, n) != n:
                raise ValueError(
                    f'index {i!r} has inconsistent sizes {dx[i]} and {n}'
                )

    op_out = list(out_spec)
    if len(set(op_out)) != len(op_out):
        raise ValueError(f'output subscripts {out_spec!r} repeat an index')
    missing = [i for i in op_out if i not in dx]
    if missing:
        raise ValueError(f'output indices {missing} do not appear in any input')
    ix_sum = [i for i in dx if i not in op_out]

    sum_ranges = [range(dx[i]) for i in ix_sum]
    output = []
    # product() varies the last index fastest, i.e. C order, which matches
    # the final reshape.
    for out_idx in itertools.product(*(range(dx[i]) for i in op_out)):
        vx = dict(zip(op_out, out_idx))
        s = 0
        for sum_idx in itertools.product(*sum_ranges):
            vx.update(zip(ix_sum, sum_idx))
            p = 1
            for op, x in zip(op_in, tensors):
                # Direct multi-index lookup; no per-term flatten() copy.
                p *= x[tuple(vx[i] for i in op)]
            s += p
        output.append(s)
    return np.array(output).reshape([dx[i] for i in op_out])
if __name__ == '__main__':
    # Sanity checks against np.einsum.  Use a tolerance: the naive loop adds
    # terms in a different order than NumPy, so bit-exact equality of floats
    # is not guaranteed.  The original bare comparisons discarded their
    # results, so running the script verified nothing.
    op = 'ij,jk->ik'
    a = np.random.randn(4, 5)
    b = np.random.randn(5, 3)
    assert np.allclose(einsum(op, a, b), np.einsum(op, a, b))

    op = 'ii->'
    a = np.random.randn(4, 4)
    assert np.allclose(einsum(op, a), np.einsum(op, a))

    op = 'i,j,k->ijk'
    a = np.random.randn(4)
    b = np.random.randn(5)
    c = np.random.randn(6)
    assert np.allclose(einsum(op, a, b, c), np.einsum(op, a, b, c))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Note that the complexity is poor!
For example, take
`'ij,jk,kl->il'`
with dimensions (1000, 2), (2, 1000), (1000, 2).
This naive algorithm evaluates
1000*2*2*1000 = 4,000,000
product terms to perform it. A better algorithm that contracts
`k`
first and `j`
second does 2*1000*2 + 1000*2*2 = 8000
multiplications. See opt_einsum for a clever contraction-ordering algorithm for einsum.