Skip to content

Instantly share code, notes, and snippets.

@shashi
Last active August 19, 2020 22:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shashi/c4569a4f1d1d2c1bf546080902cfa706 to your computer and use it in GitHub Desktop.
Save shashi/c4569a4f1d1d2c1bf546080902cfa706 to your computer and use it in GitHub Desktop.
Distributed matrix multiplication (matmul) of Dagger.jl `DArray`s using Tullio.jl
using Tullio
using Dagger
using Dagger: DArray, chunks, Thunk, ArrayDomain
"""
    Zero()

Singleton placeholder that behaves as an additive identity: adding it to any value,
on either side, returns that other value unchanged, and `Zero() + Zero()` is `Zero()`.
`matmul` fills its initial output-domain grid with it. The `zero` methods below hand it
out as the "zero" of the element types used in `matmul`'s Tullio reductions
(presumably Tullio calls `zero` on the output eltype to seed a reduction — confirm).
"""
struct Zero end

# NOTE(review): type piracy — neither `Base.zero` nor Dagger's `Thunk` is owned here,
# so this changes `zero(Thunk)` for every package loaded in the session. It returns a
# fresh delayed thunk that evaluates to `Zero()`; `matmul` fills its chunk grid with it.
Base.zero(::Type{Thunk}) = delayed(() -> Zero())()

# "Zero" of the `Union{Zero, ArrayDomain}` element type of `matmul`'s domain grid.
Base.zero(::Type{Union{Zero, ArrayDomain}}) = Zero()

# `Zero` is absorbed from either side. The (Zero, Zero) method exists to resolve the
# method ambiguity between the two one-sided definitions.
Base.:(+)(::Zero, y) = y
Base.:(+)(y, ::Zero) = y
Base.:(+)(z::Zero, ::Zero) = z
"""
    matmul(A::DArray, B::DArray)

Multiply two Dagger `DArray`s blockwise and return the product as a new `DArray`.

Block `(i, j)` of the result covers the rows of `A`'s block-row `i` and the columns of
`B`'s block-column `j`. Each output chunk is a graph of `delayed` thunks that sums,
over the shared block index `k`, the per-block products `A[i, k] * B[k, j]` (each
computed with Tullio). The element type of the result is
`promote_type(eltype(A), eltype(B))`.

Throws `DimensionMismatch` if the inner dimensions of `A` and `B` differ, or if the
column blocks of `A` do not cover the same index ranges as the row blocks of `B`.
"""
function matmul(A::DArray, B::DArray)
    dA = A.subdomains
    dB = B.subdomains

    # Validate eagerly. Without this, a mismatch only surfaces when a per-block
    # product thunk is eventually run, far from the call site.
    inner_A = A.domain.indexes[2]
    inner_B = B.domain.indexes[1]
    inner_A == inner_B || throw(DimensionMismatch(
        "matmul: A has column range $inner_A but B has row range $inner_B"))
    size(dA, 2) == size(dB, 1) || throw(DimensionMismatch(
        "matmul: A has $(size(dA, 2)) column blocks but B has $(size(dB, 1)) row blocks"))
    if !isempty(dA) && !isempty(dB)
        for k in axes(dA, 2)
            cols = dA[1, k].indexes[2]
            rows = dB[k, 1].indexes[1]
            cols == rows || throw(DimensionMismatch(
                "matmul: column block $k of A spans $cols but row block $k of B spans $rows"))
        end
    end

    # Output domains: dC[i, j] takes its row range from dA[i, k] and its column range
    # from dB[k, j]. `Zero` is the starting value; Tullio reduces over k with `+`.
    # NOTE(review): for more than one k block this relies on `+` between ArrayDomain
    # values (not defined in this file; presumably provided by Dagger) — confirm.
    dC = Union{Zero, ArrayDomain}[Zero() for i = 1:size(dA, 1), j = 1:size(dB, 2)]
    combine(da, db) = ArrayDomain(da.indexes[1], db.indexes[2])
    @tullio dC[i, j] = combine(dA[i, k], dB[k, j]) threads=false

    # Output chunks: cC[i, j] = sum over k of on_each(cA[i, k], cB[k, j]), where both
    # the products and the additions are delayed thunks.
    cA = chunks(A)
    cB = chunks(B)
    on_each = delayed((a, b) -> @tullio(c[i, j] := a[i, k] * b[k, j]))
    plus = (x, y) -> delayed(+)(x, y)
    cC = [zero(Thunk) for i = 1:size(cA, 1), j = 1:size(cB, 2)]
    # threads=false: this call only assembles thunks, so there is no numeric work
    # worth threading at this level.
    @tullio (plus) (cC[i, j] = on_each(cA[i, k], cB[k, j])) threads=false

    # `map(identity, dC)` narrows the element type away from Union{Zero, ArrayDomain}
    # once every entry has been replaced by a real domain.
    return DArray(promote_type(eltype(A), eltype(B)),
                  combine(A.domain, B.domain), map(identity, dC), cC, cat)
end
# Demo / smoke test. A is a 10000×1000 random DArray partitioned with
# Blocks(1000, 100); D is a 1000×100 random DArray partitioned with Blocks(100, 100).
A = compute(rand(Blocks(1000, 100), 10000, 1000));
D = compute(rand(Blocks(100, 100), 1000, 100));
# Build the product once and reuse it below, instead of discarding it and
# reconstructing the same thunk graph inside the test.
C = matmul(A, D)
using Test
@testset "matmul agrees with A * D" begin
    @test collect(A * D) ≈ collect(C)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment