Skip to content

Instantly share code, notes, and snippets.

@xrq-phys
Last active November 28, 2023 17:56
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xrq-phys/8b6e52f0f371acb0244950f755d7476f to your computer and use it in GitHub Desktop.
Save xrq-phys/8b6e52f0f371acb0244950f755d7476f to your computer and use it in GitHub Desktop.
Julia Implementation of NumPy's Tensordot Function, Compatible with Flux/Zygote's Automatic Differentiation
"""
    Tensordot

Minimal tensordot implementation that supports Zygote's automatic differentiation.

`contract` plays the role of NumPy's `np.tensordot`. All index bookkeeping lives in
`contractprep`, which carries a no-gradient adjoint. As a result, Zygote only
differentiates the `permutedims`, `reshape` and matrix product calls in `contractraw`.
The code is also intended to be compatible with FluxML's Tracker module.
"""
module Tensordot

using Zygote: @adjoint
using LinearAlgebra

"""
    contract(TL, TR, axesL, axesR)

Contract tensor `TL` with tensor `TR`, summing over axis `axesL[k]` of `TL` paired with
axis `axesR[k]` of `TR` for every `k`.

The interface follows `np.tensordot(TL, TR, axes=(axesL, axesR))`, with 1-based axes.
The result's dimensions are:
- the uncontracted dimensions of `TL`, in ascending axis order, followed by
- the uncontracted dimensions of `TR`, in ascending axis order.

# Throws
- `ArgumentError` if `axesL` and `axesR` differ in length, or if either contains an
  out-of-range or repeated axis.
- `DimensionMismatch` if two paired axes have different sizes.
"""
function contract(TL, TR, axesL::Array{Int}, axesR::Array{Int})
    # Shape/permutation bookkeeping; it is not differentiated (see the adjoint below).
    shapeinfo = contractprep(size(TL), size(TR), axesL, axesR)
    return contractraw(TL, TR, shapeinfo...)
end

"""
    contractraw(TL, TR, permL, permR, shapeL, shapeR, extL, extR)

Contract `TL` with `TR` using bookkeeping precomputed by `contractprep`. The steps are:
1. Permute each tensor by `permL`/`permR`.
2. Reshape each permuted tensor to the matrix shape `shapeL`/`shapeR`.
3. Multiply the two matrices.
4. Reshape the product to `(extL..., extR...)`.
"""
function contractraw(TL, TR, permL::Array{Int}, permR::Array{Int},
                     shapeL::Tuple, shapeR::Tuple, extL::Array, extR::Array)
    matL = reshape(permutedims(TL, permL), shapeL)
    matR = reshape(permutedims(TR, permR), shapeR)
    return reshape(matL * matR, (extL..., extR...))
end

# Throw an ArgumentError unless `axlist` holds distinct axes in 1:length(shape).
function validateaxes(shape::Tuple, axlist::Array{Int}, name::String)
    ndim = length(shape)
    for ax in axlist
        1 <= ax <= ndim || throw(ArgumentError(
            "$name contains axis $ax, but the tensor has only $ndim dimension(s)"))
    end
    allunique(axlist) || throw(ArgumentError("$name contains repeated axes: $axlist"))
    return nothing
end

"""
    contractprep(shapeL, shapeR, axesL, axesR)

Compute the index bookkeeping for contracting a tensor of size `shapeL` over `axesL`
with a tensor of size `shapeR` over `axesR`.

Returns `(permL, permR, (outerL, inner), (inner, outerR), extL, extR)`, where:
- `extL`/`extR` are the sizes of the free (uncontracted) axes, in ascending axis order.
- `permL` lists the free axes of the left tensor followed by `axesL`.
- `permR` lists `axesR` followed by the free axes of the right tensor.
- `outerL = prod(extL)` and `outerR = prod(extR)`.
- `inner` is the product of the contracted sizes.

`(outerL, inner)` and `(inner, outerR)` are therefore the matrix shapes of the permuted
tensors.

# Throws
- `ArgumentError` if `axesL` and `axesR` differ in length, or if either contains an
  out-of-range or repeated axis.
- `DimensionMismatch` if `shapeL[axesL[k]] != shapeR[axesR[k]]` for some `k`.
"""
function contractprep(shapeL::Tuple, shapeR::Tuple, axesL::Array{Int}, axesR::Array{Int})
    nL, nR = length(axesL), length(axesR)
    nL == nR || throw(ArgumentError(
        "axesL and axesR must have equal length, got $nL and $nR"))
    validateaxes(shapeL, axesL, "axesL")
    validateaxes(shapeR, axesR, "axesR")
    for (axL, axR) in zip(axesL, axesR)
        # Equal total sizes are not enough. Every paired axis must match, otherwise
        # the matrix product silently combines unrelated elements.
        shapeL[axL] == shapeR[axR] || throw(DimensionMismatch(
            "axis $axL of the left tensor has size $(shapeL[axL]), " *
            "but axis $axR of the right tensor has size $(shapeR[axR])"))
    end

    # Free axes in ascending order; setdiff keeps the order of its first argument.
    freeL = setdiff(1:length(shapeL), axesL)
    freeR = setdiff(1:length(shapeR), axesR)
    # Typed comprehensions so empty results (e.g. 0-dim tensors) are still Vector{Int}.
    extL = Int[shapeL[i] for i in freeL]
    extR = Int[shapeR[i] for i in freeR]

    # Contracted axes go last on the left and first on the right. This makes them the
    # shared inner dimension of the matrix product.
    permL = vcat(freeL, axesL)
    permR = vcat(axesR, freeR)

    inner = prod(Int[shapeL[i] for i in axesL])  # prod of an empty list is 1
    return permL, permR, (prod(extL), inner), (inner, prod(extR)), extL, extR
end

# contractprep only does integer bookkeeping on shapes and axes, which has nothing to
# differentiate. A pullback returning `nothing` keeps Zygote from tracing into it.
@adjoint function contractprep(shapeL::Tuple, shapeR::Tuple,
                               axesL::Array{Int}, axesR::Array{Int})
    return contractprep(shapeL, shapeR, axesL, axesR), _ -> nothing
end

end # module
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment