
@mcabbott
Created August 15, 2020 20:47

Internals

Take matrix multiplication once more. The following three macros all end up calling the same functions that C = A * B does:

@tensor C[i,j] := A[i,k] * B[k,j]         # TensorOperations.jl
@ein C[i,j] := A[i,k] * B[k,j]            # OMEinsum.jl
@matmul C[i,j] := sum(k) A[i,k] * B[k,j]  # TensorCast.jl

But this one writes its own for-loops:

@einsum C[i,j] := A[i,k] * B[k,j]         # Einsum.jl

expanding out to roughly this:

T = promote_type(eltype(A), eltype(B))
C = Array{T}(undef, size(A,1), size(B,2))
@inbounds for j in 1:size(B,2)
    for i in 1:size(A,1)
        acc = zero(T)
        for k in 1:size(A,2)
            acc += A[i,k] * B[k,j]
        end
        C[i,j] = acc
    end
end
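The expansion above can be wrapped in a function and checked against the built-in product. This is a sketch of what the generated loops compute, not Einsum.jl's literal output. The function name `einsum_matmul` and the size check are assumptions added for illustration.

```julia
# Hypothetical hand-written equivalent of `@einsum C[i,j] := A[i,k] * B[k,j]`
function einsum_matmul(A::AbstractMatrix, B::AbstractMatrix)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch("inner sizes differ"))
    T = promote_type(eltype(A), eltype(B))
    C = Array{T}(undef, size(A, 1), size(B, 2))
    @inbounds for j in 1:size(B, 2)      # columns outermost, matching the expansion above
        for i in 1:size(A, 1)
            acc = zero(T)
            for k in 1:size(A, 2)          # the summed index is innermost
                acc += A[i, k] * B[k, j]
            end
            C[i, j] = acc
        end
    end
    return C
end

A = rand(3, 4); B = rand(4, 5)
einsum_matmul(A, B) ≈ A * B   # true
```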

Tullio does something similar, but splits the work across a few functions. For a slightly more complicated example, this:

@tullio C[i,j] := tanh <| A[i,k] * B[k,j]

expands to roughly this:

function act!(::Type, C::AbstractArray{T}, A, B, ax_i, ax_j, ax_k, keep=nothing, final=true) where T
    @inbounds @fastmath for i in ax_i
        for j in ax_j
            acc = isnothing(keep) ? zero(T) : C[i,j]
            for k in ax_k
                acc += A[i,k] * B[k,j]
            end
            C[i,j] = isnothing(final) ? acc : tanh(acc)
        end
    end
end

function make(A, B)
    ax_i = axes(A,1)
    ax_j = axes(B,2)
    ax_k = axes(A,2) # and check this is == axes(B,1)
    rhs(A,B,i,j,k) = tanh(A[i,k] * B[k,j])
    T = Core.Compiler.return_type(rhs, typeof((A,B,1,1,1))) # plus a fallback
    C = similar(A, T, (ax_i, ax_j))
    Tullio.threader(act!, Array{T}, C, (A,B), (ax_i,ax_j), (ax_k,), +, 64^3)
    return C
end

C = Tullio.Eval(make, ∇make)(A, B)

Splitting the work this way lets Tullio dispatch to other methods of act!: one generated with @avx if LoopVectorization is loaded, and one for ::CuArray if KernelAbstractions is loaded.
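The mechanism is ordinary multiple dispatch on the first `::Type` argument. A minimal sketch, assuming a hypothetical kernel name `kernel!` and a trivial "fast path" standing in for the @avx or CuArray methods; the symbols returned only report which method ran.

```julia
# Generic fallback (hypothetical): plain loops, works for any AbstractArray.
function kernel!(::Type, C, A, B)
    for j in axes(C, 2), i in axes(C, 1)
        acc = zero(eltype(C))
        for k in axes(A, 2)
            acc += A[i, k] * B[k, j]
        end
        C[i, j] = acc
    end
    return :generic
end

# More specific method, selected when the type argument is a dense Array.
# It stands in for a specialised method such as one generated with @avx.
function kernel!(::Type{<:Array}, C, A, B)
    C .= A * B
    return :array
end

A = rand(2, 3); B = rand(3, 2); C = zeros(2, 2)
kernel!(typeof(C), C, A, B)         # :array
kernel!(AbstractMatrix, C, A, B)    # :generic
```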

It also allows threader to divide the work, calling act! many times, from different threads, on small tiles made by repeatedly halving the longest axis (say ax_i). If it divides ax_k, the resulting tiles write to the same elements of C, so they are run sequentially, with keep=true on all ranges except the first and final=nothing on all except the last. Tiles along ax_i and ax_j write to disjoint parts of C, so they are safe to run in parallel.
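This tiling can be imitated without Tullio. The sketch below reuses the act! shape shown above and adds a hypothetical `toy_threader`, which makes only one split of each axis, whereas the text describes Tullio halving repeatedly. The two halves of ax_i run as separate tasks, while the two chunks of ax_k run one after the other within each task.

```julia
using Base.Threads: @spawn

# Same shape as the act! kernel above, for C[i,j] := tanh <| A[i,k] * B[k,j].
function act!(::Type, C::AbstractArray{T}, A, B, ax_i, ax_j, ax_k,
              keep=nothing, final=true) where T
    @inbounds for j in ax_j, i in ax_i
        acc = isnothing(keep) ? zero(T) : C[i, j]    # resume a partial sum when keep is set
        for k in ax_k
            acc += A[i, k] * B[k, j]
        end
        C[i, j] = isnothing(final) ? acc : tanh(acc)  # apply tanh only on the last k-chunk
    end
    return C
end

# Hypothetical, much-simplified stand-in for Tullio.threader.
function toy_threader(C, A, B, ax_i, ax_j, ax_k)
    mid_i = first(ax_i) + length(ax_i) ÷ 2 - 1
    mid_k = first(ax_k) + length(ax_k) ÷ 2 - 1
    k1, k2 = first(ax_k):mid_k, (mid_k + 1):last(ax_k)
    # Halves of ax_i write disjoint rows of C, so they may run concurrently.
    tasks = map((first(ax_i):mid_i, (mid_i + 1):last(ax_i))) do ri
        @spawn begin
            act!(Array, C, A, B, ri, ax_j, k1, nothing, nothing)  # first k-chunk: start at zero, no tanh
            act!(Array, C, A, B, ri, ax_j, k2, true, true)        # last k-chunk: keep the sum, apply tanh
        end
    end
    foreach(wait, tasks)
    return C
end

A = rand(6, 5); B = rand(5, 4)
C = similar(A, 6, 4)
toy_threader(C, A, B, axes(A, 1), axes(B, 2), axes(A, 2))
C ≈ tanh.(A * B)   # true
```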

Finally, Eval exists to give Zygote and friends somewhere to attach a gradient rule. The gradient calculation looks roughly like this:

@adjoint function (e::Eval)(AB...)
    C = e.fwd(AB...)
    C, ΔC -> e.rev(ΔC, C, AB...)
end
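The role of Eval can be seen without any AD package: a callable struct carries a forward function and a matching reverse function, and a rule just pairs them. A minimal sketch, assuming the hypothetical names `MyEval`, `pullback_of`, `mul_fwd` and `mul_rev`; it uses plain matrix multiplication, not Tullio's generated functions.

```julia
# Hypothetical stand-in for Tullio.Eval: holds a forward and a gradient function.
struct MyEval{F,R}
    fwd::F
    rev::R
end

# Calling the struct just runs the forward pass.
(e::MyEval)(args...) = e.fwd(args...)

# What an AD rule does, written out by hand: run forward, then return a closure
# that maps the output sensitivity ΔC to gradients for each input.
function pullback_of(e::MyEval, args...)
    C = e.fwd(args...)
    return C, ΔC -> e.rev(ΔC, C, args...)
end

# Example pair for C = A * B: ΔA = ΔC * B', ΔB = A' * ΔC.
mul_fwd(A, B) = A * B
mul_rev(ΔC, C, A, B) = (ΔC * B', A' * ΔC)

ev = MyEval(mul_fwd, mul_rev)
A = rand(2, 3); B = rand(3, 4)
C, back = pullback_of(ev, A, B)
ΔA, ΔB = back(ones(2, 4))
```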

function ∇act!(::Type, ΔC, ΔA, ΔB, C, A, B, ax_i, ax_j, ax_k, keep)
    for k in ax_k, i in ax_i, j in ax_j
        ex = ΔC[i,j] * (1 - C[i,j]^2)  # d/dz tanh(z) = 1 - tanh(z)^2
        ΔA[i,k] += ex * B[k,j]
        ΔB[k,j] += A[i,k] * ex
    end
end

function ∇make(ΔC, C, A, B)
    ΔA = similar(A) .= 0
    ΔB = similar(B) .= 0
    ax_i, ax_k = axes(A); ax_j = axes(B,2)
    T = eltype(C)
    Tullio.∇threader(∇act!, Array{T}, (ΔC, ΔA, ΔB, C, A, B), (ax_k,), (ax_i, ax_j), nothing)
    return (ΔA, ΔB)
end

In this case, it is the loop over k that can safely be split among different threads, since both ΔA and ΔB carry this index: different ranges of k write to disjoint elements of each. Both ΔA and ΔB are filled in by the same pass.
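Splitting over k can be sketched with plain Julia threads. The names `grad_kernel!` and `grad_split_k` are hypothetical; the kernel follows the ∇act! shape above, and the chunking with `Iterators.partition` is an assumption, not Tullio's own scheme. The result is compared against the closed form G = ΔC .* (1 .- C.^2), ΔA = G * B', ΔB = A' * G.

```julia
using Base.Threads: @threads

# Gradient kernel for C = tanh.(A * B), with tanh's derivative written as 1 - y^2.
function grad_kernel!(ΔA, ΔB, ΔC, C, A, B, ax_i, ax_j, ax_k)
    @inbounds for k in ax_k, j in ax_j, i in ax_i
        ex = ΔC[i, j] * (1 - C[i, j]^2)
        ΔA[i, k] += ex * B[k, j]
        ΔB[k, j] += A[i, k] * ex
    end
    return nothing
end

# Only the k range is split: each write touches ΔA[:, k] or ΔB[k, :],
# so different k-chunks never write to the same element.
function grad_split_k(ΔC, C, A, B; nchunks = Threads.nthreads())
    ΔA = zero(A); ΔB = zero(B)
    ax_i, ax_k = axes(A); ax_j = axes(B, 2)
    chunks = collect(Iterators.partition(ax_k, max(1, cld(length(ax_k), nchunks))))
    @threads for c in chunks
        grad_kernel!(ΔA, ΔB, ΔC, C, A, B, ax_i, ax_j, c)
    end
    return ΔA, ΔB
end

A = rand(4, 6); B = rand(6, 3); C = tanh.(A * B); ΔC = rand(4, 3)
ΔA, ΔB = grad_split_k(ΔC, C, A, B; nchunks = 3)
```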

Notice that the derivative of y = tanh(z) has been written as 1 - y^2, in terms of y (the final result of the forward pass) and free of z (the result of the sum, which was not saved). If no such expression exists, the gradient will fail.
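A quick numerical check of this rewriting, as a small sketch: the helper name `dtanh_from_output` is hypothetical, and the comparison uses a central finite difference in z.

```julia
# Derivative of the finaliser, using only the forward output y = tanh(z).
dtanh_from_output(y) = 1 - y^2

z = 0.7
y = tanh(z)
h = 1e-6
fd = (tanh(z + h) - tanh(z - h)) / (2h)   # central finite difference in z
isapprox(fd, dtanh_from_output(y); atol = 1e-8)   # true

# By contrast, for y = z^2 the derivative 2z = ±2sqrt(y) is not determined
# by y alone, so a rule of this form could not be written.
```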

With grad=Dual, the gradient kernel looks different. This method cannot handle finalisers like tanh above, but for plain @tullio C[i,j] := A[i,k] * B[k,j] it would read:

function ∇act!(::Type, ΔC, ΔA, ΔB, C, A, B, ax_i, ax_j, ax_k, keep)
    eps1 = ForwardDiff.Dual(0, (1,0))
    eps2 = ForwardDiff.Dual(0, (0,1))
    for k in ax_k, i in ax_i, j in ax_j
        res = (A[i,k] + eps1) * (B[k,j] + eps2)
        ΔA[i,k] += ForwardDiff.partials(res, 1) * ΔC[i,j]
        ΔB[k,j] += ForwardDiff.partials(res, 2) * ΔC[i,j]
    end
end
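The idea behind this kernel can be reproduced with a hand-rolled dual number in place of ForwardDiff.Dual. A minimal sketch, assuming the hypothetical type `D2` (value plus two partials) and function `dual_grad!`; only the operations the kernel needs are defined. The results match the usual matrix gradients ΔA = ΔC * B' and ΔB = A' * ΔC.

```julia
# Minimal stand-in for ForwardDiff.Dual with two partials (hypothetical type).
struct D2
    val::Float64
    p::NTuple{2,Float64}   # partials with respect to the two perturbed inputs
end
Base.:+(a::D2, b::Real) = D2(a.val + b, a.p)
Base.:+(a::Real, b::D2) = b + a
Base.:*(a::D2, b::D2) = D2(a.val * b.val, a.p .* b.val .+ a.val .* b.p)  # product rule

# Same shape as the Dual-based kernel above, for C[i,j] := A[i,k] * B[k,j].
function dual_grad!(ΔA, ΔB, ΔC, A, B)
    eps1 = D2(0.0, (1.0, 0.0))   # marks the A factor
    eps2 = D2(0.0, (0.0, 1.0))   # marks the B factor
    for k in axes(A, 2), i in axes(A, 1), j in axes(B, 2)
        res = (A[i, k] + eps1) * (B[k, j] + eps2)
        ΔA[i, k] += res.p[1] * ΔC[i, j]   # partial w.r.t. A[i,k] is B[k,j]
        ΔB[k, j] += res.p[2] * ΔC[i, j]   # partial w.r.t. B[k,j] is A[i,k]
    end
    return ΔA, ΔB
end

A = rand(3, 4); B = rand(4, 2); ΔC = rand(3, 2)
ΔA, ΔB = dual_grad!(zeros(3, 4), zeros(4, 2), ΔC, A, B)
```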

Writing @tullio verbose=2 will print all of these functions out.
