Take matrix multiplication again. The following three macros all end up calling the same functions as `C = A * B` does:
# Each of these ends up calling the same routine as C = A * B:
@tensor C[i,j] := A[i,k] * B[k,j] # TensorOperations.jl
@ein C[i,j] := A[i,k] * B[k,j] # OMEinsum.jl
@matmul C[i,j] := sum(k) A[i,k] * B[k,j] # TensorCast.jl
But this one writes its own for-loops:
# Einsum.jl writes out the loops itself instead of calling a library routine:
@einsum C[i,j] := A[i,k] * B[k,j] # Einsum.jl
expanding out to roughly this:
# Common element type of A and B, used for both the accumulator and the output.
T = promote_type(eltype(A), eltype(B))
C = Array{T}(undef, size(A,1), size(B,2))
# j outermost, then i: C is filled column by column (Julia arrays are column-major).
# NOTE(review): this sketch never checks size(A,2) == size(B,1); under @inbounds
# a mismatch would read out of bounds instead of throwing.
@inbounds for j in 1:size(B,2)
    for i in 1:size(A,1)
        acc = zero(T)
        for k in 1:size(A,2)  # sum over the contracted index k
            acc += A[i,k] * B[k,j]
        end
        C[i,j] = acc
    end
end
Tullio does something similar, but works through a few functions. Taking a slightly more complicated example, this:
# `tanh <|` makes tanh a finaliser: it is applied to each completed sum over k.
@tullio C[i,j] := tanh <| A[i,k] * B[k,j]
expands to roughly this:
# Kernel that writes C[i,j] for every i in ax_i and j in ax_j, summing over ax_k.
# The leading ::Type argument is unused in this body; it exists so that other
# methods (e.g. for ::CuArray) can be selected by dispatch.
# keep === nothing : start each sum from zero; otherwise continue from C[i,j].
# final === nothing: store the raw partial sum; otherwise apply tanh to it.
function act!(::Type, C::AbstractArray{T}, A, B, ax_i, ax_j, ax_k, keep=nothing, final=true) where T
    @inbounds @fastmath for i in ax_i
        for j in ax_j
            # resume a partial sum when the k range has been split into tiles
            acc = isnothing(keep) ? zero(T) : C[i,j]
            for k in ax_k
                acc += A[i,k] * B[k,j]
            end
            # the finaliser is applied only once the whole k range has been summed
            C[i,j] = isnothing(final) ? acc : tanh(acc)
        end
    end
end
# Allocating entry point: determines the index ranges and the output element type,
# allocates C, then hands the kernel act! to Tullio.threader to fill it in.
function make(A, B)
    ax_i = axes(A,1)
    ax_j = axes(B,2)
    ax_k = axes(A,2) # and check this is == axes(B,1)
    # Scalar body of the expression; in this sketch it is used only to infer T.
    rhs(A,B,i,j,k) = tanh(A[i,k] * B[k,j])
    # NOTE(review): schematic — Core.Compiler.return_type takes a Tuple type of
    # argument types; the real generated code may build that differently.
    T = Core.Compiler.return_type(rhs, eltype.((A,B,1,1,1))) # plus a fallback
    C = similar(A, T, (ax_i, ax_j))
    # (ax_i,ax_j) may be tiled across threads; (ax_k,) is the summed range.
    # NOTE(review): `+` presumably names the reduction and 64^3 a work-size
    # threshold for threading — confirm against Tullio.threader.
    Tullio.threader(act!, Array{T}, C, (A,B), (ax_i,ax_j), (ax_k,), +, 64^3)
    return C
end
# Eval bundles the forward function make with its gradient ∇make; calling it runs make(A, B).
C = Tullio.Eval(make, ∇make)(A, B)
This division allows it to dispatch to other methods of `act!`: one generated with `@avx` if LoopVectorization is loaded, and one for `::CuArray` if KernelAbstractions is loaded.
It also allows `threader` to divide the work, calling `act!` many times, from different threads, on small tiles made by repeatedly halving the longest axis (say `ax_i`). If it divides up `ax_k`, these tiles are done sequentially, with `keep=true` on all ranges except the first, and `final=nothing` on all except the last. But `ax_i` and `ax_j` are safe to do in parallel, because tiles that differ only in `i` or `j` write to disjoint parts of `C`.
Finally, `Eval` exists to give Zygote and friends somewhere to attach themselves. The gradient calculation looks roughly like this:
# Zygote rule for a callable Eval: run the forward function e.fwd, keep its result C,
# and return a pullback that passes the incoming gradient ΔC, the saved C and the
# original inputs to the reverse function e.rev.
# NOTE(review): depending on Zygote's convention for callable structs, the pullback
# may also need to return a gradient slot for `e` itself — confirm.
@adjoint function (e::Eval)(AB...)
    C = e.fwd(AB...)
    C, ΔC -> e.rev(ΔC, C, AB...)
end
# Gradient kernel for C = tanh.(A * B): accumulates into ΔA and ΔB over the given
# index ranges. Uses d tanh(z)/dz = 1 - tanh(z)^2 = 1 - C^2, so only the saved
# forward result C is needed, not the pre-activation sum z.
# The leading ::Type argument and `keep` are unused in this sketch.
function ∇act!(::Type, ΔC, ΔA, ΔB, C, A, B, ax_i, ax_j, ax_k, keep)
    for k in ax_k, i in ax_i, j in ax_j
        ex = ΔC[i,j] * (1 - C[i,j]^2)  # chain rule through tanh
        ΔA[i,k] += ex * B[k,j]
        ΔB[k,j] += A[i,k] * ex
    end
end
# Reverse pass for the example above (schematic): allocates zeroed gradient arrays,
# then lets Tullio.∇threader run the kernel ∇act! to fill them in.
# Returns the gradients with respect to A and B, in that order.
function ∇make(ΔC, C, A, B)
    ΔA = similar(A) .= 0
    ΔB = similar(B) .= 0
    ax_i, ax_k = axes(A); ax_j = axes(B,2)
    # Element type of the forward result, matching the Array{T} the forward pass uses.
    T = eltype(C)
    # Only ax_k may be split across threads: ΔA[i,k] and ΔB[k,j] both carry k, so
    # different k-tiles write disjoint entries. Splitting i or j would make two
    # threads write the same ΔB[k,j] or ΔA[i,k].
    # NOTE(review): argument layout assumed to mirror Tullio.threader
    # (kernel, storage type, arrays, axis tuples, block) — confirm.
    Tullio.∇threader(∇act!, Array{T}, (ΔC, ΔA, ΔB, C, A, B), (ax_k,), (ax_i, ax_j), nothing)
    return (ΔA, ΔB)
end
In this case, it is the loop over `k` that can be safely broken up among different threads, since both `ΔA` and `ΔB` have this index. Both `ΔA` and `ΔB` are filled in at once.
Notice that the derivative of `y = tanh(z)` has been written in terms of `y` (the final result of the forward pass), namely `1 - y^2`, but free of `z` (the result of the sum, which was not saved). If this is not possible, the gradient calculation will fail.
If using `grad=Dual`, the gradient kernel looks different. This method cannot handle finalisers like `tanh` above, but for plain `@tullio C[i,j] := A[i,k] * B[k,j]` it would read:
# Gradient kernel for plain C = A * B using dual numbers. A[i,k] and B[k,j] are
# perturbed in two independent directions, the product is evaluated once, and each
# partial derivative is read back out and scaled by the incoming gradient ΔC[i,j].
# The leading ::Type argument, C and `keep` are unused in this sketch.
function ∇act!(::Type, ΔC, ΔA, ΔB, C, A, B, ax_i, ax_j, ax_k, keep)
    eps1 = ForwardDiff.Dual(0, (1,0))  # perturbation in the A direction
    eps2 = ForwardDiff.Dual(0, (0,1))  # perturbation in the B direction
    for k in ax_k, i in ax_i, j in ax_j
        res = (A[i,k] + eps1) * (B[k,j] + eps2)
        ΔA[i,k] += ForwardDiff.partials(res, 1) * ΔC[i,j]  # d(res)/dA[i,k]
        ΔB[k,j] += ForwardDiff.partials(res, 2) * ΔC[i,j]  # d(res)/dB[k,j]
    end
end
Writing `@tullio verbose=2` will print all of these functions out.