@GiggleLiu · Last active November 2, 2020
The "copy A" MMA kernel: instead of performing a real matrix multiply-accumulate, `mymma` builds the output fragment directly from the A fragment's elements, so the final check is `d ≈ a` rather than the usual `alpha * a * b + beta * c` test.
using CUDA
using CUDA.WMMA
using CUDA.WMMA: ColMajor, load_a, load_b, load_c, mma, store_d, Config, Fragment
using StaticArrays
# host-side 16×16 Float16 operands: a single WMMA tile
a = rand(Float16, 16, 16)
b = rand(Float16, 16, 16)
c = zeros(Float16, 16, 16)
d = zeros(Float16, 16, 16)

# device copies
a_dev = CuArray(a)
b_dev = CuArray(b)
c_dev = CuArray(c)
d_dev = CuArray(d)

# scalar factors, kept in Float16 to match the fragment element type
alpha = Float16(1)
beta = Float16(1)
# `mymma` replaces the real WMMA `mma` call: it returns an accumulator fragment whose
# first 8 slots are copied from the A fragment and whose remaining slots are zero.
@inline @generated function mymma(a_frag, b_frag, c_frag, conf)
    # build the tuple expression (a_frag.x[1], ..., a_frag.x[8], 0, ..., 0) at compile time
    out = Expr(:tuple, [:(a_frag.x[$i]) for i=1:8]..., zeros(Float16, 8)...)
    quote
        i = CUDA.threadIdx().x  # per-lane thread index (currently unused)
        return Fragment{16,16,16,16,Float16,WMMA.Unspecified,WMMA.Accumulator}($out)
    end
end
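# For reference, the same accumulator fragment could also be built without a generated
# function. This is a minimal sketch, not part of the original gist: the name
# `mymma_ntuple` is hypothetical, and it assumes (as the constructor call above does)
# that Fragment accepts a flat NTuple{16,Float16}; `ntuple` over Val(16) keeps the
# construction unrolled and type-stable.
@inline function mymma_ntuple(a_frag, b_frag, c_frag, conf)
    # first 8 elements copied from the A fragment, remaining 8 filled with zeros
    vals = ntuple(i -> i <= 8 ? a_frag.x[i] : Float16(0), Val(16))
    return Fragment{16,16,16,16,Float16,WMMA.Unspecified,WMMA.Accumulator}(vals)
end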
function kernel(a_dev, b_dev, c_dev, d_dev, alpha, beta)
    conf = Config{16, 16, 16, Float16}

    # load the 16×16 tiles into per-thread WMMA fragments
    a_frag = load_a(pointer(a_dev), 16, ColMajor, conf)
    b_frag = load_b(pointer(b_dev), 16, ColMajor, conf)
    c_frag = load_c(pointer(c_dev), 16, ColMajor, conf)

    # element-wise scaling of the fragments
    a_frag = alpha .* a_frag
    c_frag = beta .* c_frag

    # "copy A" instead of a real multiply-accumulate
    d_frag = mymma(a_frag, b_frag, c_frag, conf)

    store_d(pointer(d_dev), d_frag, 16, ColMajor, conf)
    return
end
# one warp (32 threads) handles the whole 16×16×16 WMMA tile
@cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev, alpha, beta)
d = Array(d_dev)

# the full multiply-accumulate check, disabled because mymma only copies A:
#@test all(isapprox.(alpha * a * b + beta * c, d; rtol=sqrt(eps(Float16))))
@assert d ≈ a  # D should reproduce A
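# For comparison, a minimal sketch of the same kernel with the real `mma` call from
# CUDA.WMMA (the variant the commented-out test above targets). Everything except the
# swap from `mymma` to `mma` mirrors the kernel above; the name `kernel_mma` and the
# commented launch/check lines are illustrative, not from the original gist.
function kernel_mma(a_dev, b_dev, c_dev, d_dev, alpha, beta)
    conf = Config{16, 16, 16, Float16}
    a_frag = load_a(pointer(a_dev), 16, ColMajor, conf)
    b_frag = load_b(pointer(b_dev), 16, ColMajor, conf)
    c_frag = load_c(pointer(c_dev), 16, ColMajor, conf)
    a_frag = alpha .* a_frag
    c_frag = beta .* c_frag
    d_frag = mma(a_frag, b_frag, c_frag, conf)   # real matrix multiply-accumulate
    store_d(pointer(d_dev), d_frag, 16, ColMajor, conf)
    return
end
# @cuda threads=32 kernel_mma(a_dev, b_dev, c_dev, d_dev, alpha, beta)
# all(isapprox.(alpha * a * b + beta * c, Array(d_dev); rtol=sqrt(eps(Float16))))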