Skip to content

Instantly share code, notes, and snippets.

@johnbcoughlin
Created May 22, 2024 14:23
Show Gist options
  • Save johnbcoughlin/aad1b552ee6697bdd112921c2095d2cf to your computer and use it in GitHub Desktop.
Save johnbcoughlin/aad1b552ee6697bdd112921c2095d2cf to your computer and use it in GitHub Desktop.
using LinearAlgebra
"""
    test_impls()

Time and display the output of the pairwise-kernel implementations in this file
on random `d×n` inputs with `γ = -3`, and log the largest absolute difference
between the unrolled kernel (`landau_f_aux!`) and `matrix_impl2!`.
Returns `nothing`.
"""
function test_impls()
    d = 10
    n = 100
    du = zeros(d, n)
    u = rand(d, n)
    s = rand(d, n)
    # The unrolled kernel must be instantiated for the actual row count of `u`;
    # derive it from `d` rather than a literal so the two can never drift apart.
    @time landau_f_aux!(du, u, s, Val(d); γ=-3)
    du1 = copy(du)
    du .= 0
    #@time reference_impl!(du, u, s; γ=-3)
    du2 = copy(du)
    du .= 0
    display(du1)
    #display(du2)
    # Allocate scratch space
    normz_pow_γ = zeros(n, n)
    Z = zeros(d, n, n)
    V = zeros(d, n, n)
    ZV = zeros(d, n, n)
    #@time matrix_impl!(du, u, s; γ=-3, Z, V, ZV, normz_pow_γ)
    #display(du)
    @time matrix_impl2!(du, u, s; γ=-3, Z, V, ZV, normz_pow_γ)
    display(du)
    # Cross-check the two formulations instead of relying on visual inspection.
    @info "max |landau_f_aux! - matrix_impl2!|" maximum(abs, du1 .- du)
    return nothing
end
# landau_f_aux!(du, u, s, Val(d); γ)
#
# Overwrite `du` (which must have the same size as `u` and `s`) so that, for
# every column p,
#     du[:, p] = Σ_q (v * |z|^2 - (z ⋅ v) * z) * (|z| + ε)^γ
# where z = u[:, p] - u[:, q], v = s[:, q] - s[:, p] and ε = eps(eltype(du)).
# The loops over the d rows are unrolled at compile time with
# `Base.Cartesian.@nexprs`, so `d` must equal `size(u, 1)`; a DimensionMismatch
# is thrown otherwise, as well as when `du`, `u` and `s` differ in size.
# The q == p term has z = v = 0, so it contributes zero as long as ε^γ is finite.
@generated function landau_f_aux!(du, u, s, ::Val{d}; γ) where {d}
    quote
        size(u, 1) == $d || throw(DimensionMismatch(string(
            "size(u, 1) = ", size(u, 1), " does not match Val(", $d, ")")))
        size(du) == size(u) == size(s) || throw(DimensionMismatch(
            "du, u and s must have the same size"))
        T = eltype(du)
        ε = eps(T)
        n = size(u, 2)
        for p = 1:n
            # Per-row accumulators dx_1 … dx_d for column p.
            Base.Cartesian.@nexprs $d i -> dx_i = zero(T)
            for q = 1:n
                dotzv = zero(T)
                normsqz = zero(T)
                Base.Cartesian.@nexprs $d i -> begin
                    z_i = u[i, p] - u[i, q]
                    v_i = s[i, q] - s[i, p]
                    dotzv += z_i * v_i
                    normsqz += z_i * z_i
                end
                normz_pow_γ = 1/(sqrt(normsqz) + ε)^(-γ)
                Base.Cartesian.@nexprs $d i -> begin
                    dx_i += (v_i * normsqz - dotzv * z_i) * normz_pow_γ
                end
            end
            # Every (i, p) entry is written exactly once, so no pre-zeroing is needed.
            Base.Cartesian.@nexprs $d i -> begin
                du[i, p] = dx_i
            end
        end
        nothing
    end
end
"""
    reference_impl!(du, u, s; γ=-3)

Plain-loop reference for the pairwise kernel. For every column `p`, add to
`du[:, p]` the sum over all columns `q` of

    (v * |z|^2 - (z ⋅ v) * z) * (|z| + ε)^γ

with `z = u[:, p] - u[:, q]`, `v = s[:, q] - s[:, p]` and `ε = eps(eltype(du))`.

The result is *accumulated* into `du` (it is not zeroed first). All arithmetic
uses `eltype(du)` accumulators. Returns `nothing`.
"""
function reference_impl!(du, u, s; γ=-3)
    T = eltype(du)
    d = size(u, 1)
    n = size(u, 2)
    ε = eps(T)
    # Only column p is accumulated at a time, so a length-d buffer suffices
    # (instead of a d×n matrix that would be re-zeroed on every p).
    dx = zeros(T, d)
    for p in 1:n
        fill!(dx, zero(T))
        for q in 1:n
            dotzv = zero(T)
            normsqz = zero(T)
            for i in 1:d
                z_i = u[i, p] - u[i, q]
                v_i = s[i, q] - s[i, p]
                dotzv += z_i * v_i
                normsqz += z_i * z_i
            end
            normz_pow_γ = 1/(sqrt(normsqz) + ε)^(-γ)
            for i in 1:d
                z_i = u[i, p] - u[i, q]
                v_i = s[i, q] - s[i, p]
                dx[i] += (v_i * normsqz - dotzv * z_i) * normz_pow_γ
            end
        end
        for i in 1:d
            du[i, p] += dx[i]
        end
    end
    return nothing
end
"""
    matrix_impl!(du, u, s; γ=-3, Z, V, ZV, normz_pow_γ)

Overwrite `du` so that, for every column `p`,

    du[:, p] = Σ_q (v * |z|^2 - (z ⋅ v) * z) * (|z| + ε)^γ

with `z = u[:, p] - u[:, q]`, `v = s[:, q] - s[:, p]`, `ε = eps(eltype(du))`.
All pairwise differences are materialised as `d×n×n` tensors indexed `(i, p, q)`.

# Keywords
- `γ`: exponent applied to `|z| + ε` (default `-3`).
- `Z`, `V`, `ZV`: `d×n×n` scratch arrays, overwritten with `z`, `v` and `z .* v`.
- `normz_pow_γ`: `n×n` scratch array, overwritten with `(|z| + ε)^γ` at `(p, q)`.

Returns `nothing`.
"""
function matrix_impl!(du, u, s; γ=-3, Z, V, ZV, normz_pow_γ)
    dim = size(u, 1)
    npts = size(u, 2)
    tol = eps(eltype(du))
    # Move the particle axis to position 3 so broadcasting against the
    # (dim, npts) originals produces arrays indexed (i, p, q).
    u_q = reshape(u, (dim, 1, npts))
    s_q = reshape(s, (dim, 1, npts))
    Z .= u .- u_q
    V .= s_q .- s
    ZV .= Z .* V
    # Reduce over the component axis i, leaving (p, q) matrices.
    zv = reshape(sum(ZV, dims=1), (npts, npts))
    zz = reshape(sum(abs2, Z, dims=1), (npts, npts))
    normz_pow_γ .= 1 ./ (sqrt.(zz) .+ tol) .^ (-γ)
    zv .*= normz_pow_γ
    zz .*= normz_pow_γ
    # Contract over q with p fixed:
    #   du[:, p] = V[:, p, :] * zz[p, :] - Z[:, p, :] * zv[p, :]
    @views for p in 1:npts
        col = du[:, p]
        mul!(col, V[:, p, :], zz[p, :])
        mul!(col, Z[:, p, :], zv[p, :], -1, 1)
    end
    return nothing
end
"""
    matrix_impl2!(du, u, s; γ=-3, Z, V, ZV, normz_pow_γ)

Overwrite `du` so that, for every column `p`,

    du[:, p] = Σ_q (v * |z|^2 - (z ⋅ v) * z) * (|z| + ε)^γ

with `z = u[:, p] - u[:, q]`, `v = s[:, q] - s[:, p]`, `ε = eps(eltype(du))`.
The pairwise scalars `|z|^2` and `z ⋅ v` are obtained from the Gram matrices
`u' * u` and `u' * s` rather than by reducing over the difference tensors.

# Keywords
- `γ`: exponent applied to `|z| + ε` (default `-3`).
- `Z`, `V`: `d×n×n` scratch arrays, overwritten with `z` and `v` indexed `(i, p, q)`.
- `ZV`: not used here; accepted so the keyword set matches `matrix_impl!`.
- `normz_pow_γ`: `n×n` scratch array, overwritten with `(|z| + ε)^γ` at `(p, q)`.

Returns `nothing`.
"""
function matrix_impl2!(du, u, s; γ=-3, Z, V, ZV, normz_pow_γ)
    d = size(u, 1)
    n = size(u, 2)
    ϵ = eps(eltype(du))
    uT = reshape(u, (d, 1, n))
    sT = reshape(s, (d, 1, n))
    # Gram matrices: u_dot_u[p, q] = u_p ⋅ u_q, u_dot_s[p, q] = u_p ⋅ s_q.
    u_dot_u = u' * u
    u_norm_sq = diag(u_dot_u)
    u_dot_s = u' * s
    u_dot_s_diag = diag(u_dot_s)
    # normsqz = |u_p - u_q|^2
    # = |u_p|^2 + |u_q|^2 - 2*u_p ⋅ u_q
    # This expansion cancels catastrophically for nearby particles and can round
    # to a tiny negative number, which would make `sqrt` below throw a
    # DomainError. The exact value is non-negative, so clamp at zero.
    zero_sq = zero(eltype(u_dot_u))
    normsqz = @. max(u_norm_sq + u_norm_sq' - 2*u_dot_u, zero_sq)
    # dotzv = (u_p - u_q) ⋅ (s_q - s_p)
    # = (u_p ⋅ s_q) - (u_q ⋅ s_q) - (u_p ⋅ s_p) + (u_q ⋅ s_p)
    dotzv = @. u_dot_s + u_dot_s' - u_dot_s_diag - u_dot_s_diag'
    # Indices are (i, p, q)
    @. Z = u - uT
    @. V = sT - s
    @. normz_pow_γ = 1/(sqrt(normsqz) + ϵ)^(-γ)
    @. dotzv = dotzv * normz_pow_γ
    @. normsqz = normsqz * normz_pow_γ
    # The final assembly is a contraction over the q index, keeping p fixed.
    # This can be accomplished with a fast gemv_batched kernel call:
    # https://docs.nvidia.com/cuda/cublas/#cublas-t-gemvbatched
    # https://fluxml.ai/NNlib.jl/dev/reference/#NNlib.batched_mul
    for p in 1:n
        mul!((@view du[:, p]), (@view V[:, p, :]), (@view normsqz[p, :]))
        mul!((@view du[:, p]), (@view Z[:, p, :]), (@view dotzv[p, :]), -1, 1)
    end
    return nothing
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment