Skip to content

Instantly share code, notes, and snippets.

@xrq-phys
Last active July 25, 2020 15:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xrq-phys/a90e3692a235f7c61c985df4039f6ff1 to your computer and use it in GitHub Desktop.
Save xrq-phys/a90e3692a235f7c61c985df4039f6ff1 to your computer and use it in GitHub Desktop.
Prototype of a TBLIS-backed tensor contraction (`contract!`) for ForwardDiff.jl dual-number arrays
"""
    contract!(A::Array{T}, B::Array{T}, C::Array{T}, idx) where {T<:Dual}

Contract the dual-number arrays `A` and `B` into `C` according to `idx`. `idx` has
3 entries corresponding to an index specification such as `"ik,jk->ij"`.

The work is delegated layer by layer: each `Dual` layer is unpacked via the product rule
(see the `Type{Dual{...}}` method), until a plain scalar type is reached. At that point
the leaf method handles the contraction on strided views into the original memory.

`C` is cleared first. The product rule sends several contributions into each slot of
`C`, so the leaf contractions are expected to accumulate into `C`.
"""
function contract!(A::Array{T}, B::Array{T}, C::Array{T}, idx) where {T<:Dual}
    # Number of base scalars per top-level element ("top-level stride"). `tovalue` is
    # expected to unveil the innermost base type (NOTE(review): confirm it peels all
    # nested Dual layers). Integer division keeps the derived strides integral; `/`
    # would produce a Float64.
    topst = sizeof(T) ÷ sizeof(tovalue(T))
    # Start from zero so the accumulating leaf contractions produce C = A ⋅ B.
    fill!(C, zero(T))
    contract!(T, topst, A, 0, B, 0, C, 0, idx)
    return nothing
end
"""
    contract!(::Type{Dual{Tg,T,ND}}, topst, A, sftA, B, sftB, C, sftC, idx)

Unpack one `Dual` layer of the contraction using the product rule:

    value(C)     = value(A) ⋅ value(B)
    partial_i(C) = partial_i(A) ⋅ value(B) + value(A) ⋅ partial_i(B),  i = 1:ND

Each term is contracted by recursing with the value type `T`. `T` may itself be a
`Dual`, in which case recursion peels the next layer.

- `topst` is the top-level stride: base scalars per top-level element. It is passed
  through unchanged.
- `A`, `B` and `C` are always the top-level arrays.
- `sftA`, `sftB` and `sftC` are byte offsets that select the current sub-component
  within each top-level element.

Every partial slot of `C` receives two contributions. The leaf contraction must
therefore accumulate into `C`, which the caller is expected to have zeroed.
"""
function contract!(::Type{Dual{Tg, T, ND}}, topst,  # topst: top-level stride
                   A::Array, sftA,  # arrays here are all at their top level (not dispatched)
                   B::Array, sftB,
                   C::Array, sftC,
                   idx) where {Tg, T, ND}
    # Value part: direct dispatch on the value type; the value sits at the layer's offset.
    contract!(T, topst, A, sftA, B, sftB, C, sftC, idx)
    # Differentials. Assumes the Dual layout puts partial `id` exactly `id * sizeof(T)`
    # bytes after the value (value followed by ND partials of type T).
    # TODO: consider exchangeability.
    for id in 1:ND
        off = id * sizeof(T)
        # partial_id(A) ⋅ value(B)
        contract!(T, topst, A, sftA + off, B, sftB, C, sftC + off, idx)
        # value(A) ⋅ partial_id(B)
        contract!(T, topst, A, sftA, B, sftB + off, C, sftC + off, idx)
    end
    return nothing
end
"""
    contract!(::Type{<:ValueType}, topst, A, sftA, B, sftB, C, sftC, idx)

Leaf contraction on plain scalar components of the top-level arrays. This is currently
a stub that prepares the strided raw views; the TBLIS call itself is still TODO.

`ValueType` is defined elsewhere. NOTE(review): it is presumably a `Union` of the
TBLIS-supported scalar types, so it is matched on the *type* argument here.

- `topst` converts strides from top-level (Dual) element units into base-scalar units.
- `sftA`, `sftB` and `sftC` are byte offsets into each array's first element that
  select the scalar component to contract.
- The implementation must accumulate into `C` (β = 1). Callers decompose a Dual
  product into several terms written to the same slots of `C`.
"""
function contract!(::Type{<:ValueType},
                   topst,
                   A::Array, sftA,
                   B::Array, sftB,
                   C::Array, sftC,
                   idx)
    # `strides` of the top-level arrays count top-level elements; scale them to the
    # base scalar type that TBLIS will operate on.
    stA = topst .* strides(A)
    stB = topst .* strides(B)
    stC = topst .* strides(C)
    # Keep the arrays rooted while raw pointers into them are alive.
    GC.@preserve A B C begin
        # Bare memory shifted by sft{A,B,C} bytes (pointer arithmetic on Ptr is bytewise).
        pA = Ptr{Cvoid}(pointer(A)) + sftA
        pB = Ptr{Cvoid}(pointer(B)) + sftB
        pC = Ptr{Cvoid}(pointer(C)) + sftC
        # TODO:
        # - build TBLIS tensor objects from (pX, size(X), stX) for X in A, B, C;
        # - TBLIS-contract according to `idx`, accumulating into C (β = 1).
    end
    return nothing
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment