Created
May 22, 2024 14:23
-
-
Save johnbcoughlin/aad1b552ee6697bdd112921c2095d2cf to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using LinearAlgebra
"""
    test_impls()

Time two implementations of the pairwise kernel on random `d × n` data and
report how closely they agree.

The unrolled `landau_f_aux!` result is displayed first, followed by the result
of `matrix_impl2!`; the largest absolute entry-wise difference between the two
is then logged. Returns `nothing`.
"""
function test_impls()
    d = 10
    n = 100

    du = zeros(d, n)
    u = rand(d, n)
    s = rand(d, n)

    # Tie the unroll width to `d` instead of repeating the literal, so the two
    # cannot drift apart when `d` is edited.
    @time landau_f_aux!(du, u, s, Val(d); γ=-3)
    du1 = copy(du)
    display(du1)

    # Scratch space for the matrix formulation.
    normz_pow_γ = zeros(n, n)
    Z = zeros(d, n, n)
    V = zeros(d, n, n)
    ZV = zeros(d, n, n)

    du .= 0
    @time matrix_impl2!(du, u, s; γ=-3, Z, V, ZV, normz_pow_γ)
    display(du)

    # Make the comparison explicit rather than relying on eyeballing the two
    # displayed matrices.
    maxdev = maximum(abs, du1 .- du)
    @info "landau_f_aux! vs matrix_impl2!: max abs deviation" maxdev
    return nothing
end
"""
    landau_f_aux!(du, u, s, ::Val{d}; γ=-3)

Overwrite `du` with the pairwise sum

    du[:, p] = Σ_q (|z|² v - (z ⋅ v) z) / (|z| + ε)^(-γ)

where `z = u[:, p] - u[:, q]`, `v = s[:, q] - s[:, p]` and `ε = eps(eltype(du))`.

The loops over the first (component) index are unrolled at compile time to
exactly `d` iterations, so only rows `1:d` of `u`, `s` and `du` are touched.

# Arguments
- `du`: output array, zeroed and then filled in place.
- `u`, `s`: input arrays indexed as `[component, particle]`.
- `Val(d)`: number of components to unroll over.

# Keywords
- `γ`: exponent of the distance weight. Defaults to `-3`, matching the other
  implementations in this file.

Returns `nothing`.
"""
@generated function landau_f_aux!(du, u, s, ::Val{d}; γ=-3) where {d}
    quote
        T = eltype(du)
        ε = eps(T)
        du .= 0
        n = size(u, 2)
        for p in 1:n
            # One scalar accumulator per component: dx_1, …, dx_d.
            Base.Cartesian.@nexprs $d i -> dx_i = zero(T)
            for q in 1:n
                dotzv = zero(T)
                normsqz = zero(T)
                Base.Cartesian.@nexprs $d i -> begin
                    z_i = u[i, p] - u[i, q]
                    v_i = s[i, q] - s[i, p]
                    dotzv += z_i * v_i
                    normsqz += z_i * z_i
                end
                # ε keeps the weight finite when p == q (z == 0); the bracket
                # below is exactly zero there, so the self term contributes 0.
                normz_pow_γ = 1 / (sqrt(normsqz) + ε)^(-γ)
                Base.Cartesian.@nexprs $d i -> begin
                    dx_i += (v_i * normsqz - dotzv * z_i) * normz_pow_γ
                end
            end
            Base.Cartesian.@nexprs $d i -> begin
                du[i, p] += dx_i
            end
        end
        return nothing
    end
end
"""
    reference_impl!(du, u, s; γ=-3)

Straightforward triple-loop evaluation of the pairwise sum

    Σ_q (|z|² v - (z ⋅ v) z) / (|z| + ε)^(-γ)

with `z = u[:, p] - u[:, q]`, `v = s[:, q] - s[:, p]` and
`ε = eps(eltype(du))`.

The result for each column `p` is **added** to `du[:, p]`; `du` is not zeroed
first, so callers that want the bare sum must clear `du` beforehand.

# Arguments
- `du`: array accumulated into, indexed as `[component, particle]`.
- `u`, `s`: input arrays with the same layout.

# Keywords
- `γ`: exponent of the distance weight (default `-3`).

Returns `nothing`.
"""
function reference_impl!(du, u, s; γ=-3)
    d = size(u, 1)
    n = size(u, 2)
    T = eltype(du)
    ε = eps(T)
    # Only one column is accumulated at a time, so a length-d buffer suffices.
    acc = zeros(T, d)
    for p in 1:n
        fill!(acc, zero(T))
        for q in 1:n
            dotzv = zero(T)
            normsqz = zero(T)
            for i in 1:d
                z_i = u[i, p] - u[i, q]
                v_i = s[i, q] - s[i, p]
                dotzv += z_i * v_i
                normsqz += z_i * z_i
            end
            # ε keeps the weight finite for the p == q term, whose bracket is 0.
            normz_pow_γ = 1 / (sqrt(normsqz) + ε)^(-γ)
            for i in 1:d
                z_i = u[i, p] - u[i, q]
                v_i = s[i, q] - s[i, p]
                acc[i] += (v_i * normsqz - dotzv * z_i) * normz_pow_γ
            end
        end
        for i in 1:d
            du[i, p] += acc[i]
        end
    end
    return nothing
end
"""
    matrix_impl!(du, u, s; γ=-3, Z, V, ZV, normz_pow_γ)

Array-at-a-time evaluation of the pairwise sum

    du[:, p] = Σ_q (|z|² v - (z ⋅ v) z) / (|z| + ε)^(-γ)

with `z = u[:, p] - u[:, q]`, `v = s[:, q] - s[:, p]` and
`ε = eps(eltype(du))`. `du` is overwritten.

# Arguments
- `du`: `d × n` output.
- `u`, `s`: `d × n` inputs indexed as `[component, particle]`.

# Keywords
- `γ`: exponent of the distance weight (default `-3`).
- `Z`, `V`, `ZV`: `d × n × n` scratch arrays, fully overwritten. Allocated on
  demand when omitted.
- `normz_pow_γ`: `n × n` scratch array, fully overwritten. Allocated on demand
  when omitted.

Returns `nothing`.
"""
function matrix_impl!(
    du, u, s;
    γ=-3,
    Z=similar(du, size(u, 1), size(u, 2), size(u, 2)),
    V=similar(Z),
    ZV=similar(Z),
    normz_pow_γ=similar(du, size(u, 2), size(u, 2)),
)
    d = size(u, 1)
    n = size(u, 2)
    ϵ = eps(eltype(du))

    # Views of the inputs with the particle index moved to the third slot, so
    # that broadcasting against the d × n originals produces all (p, q) pairs.
    uT = reshape(u, (d, 1, n))
    sT = reshape(s, (d, 1, n))

    # Indices are (i, p, q)
    @. Z = u - uT
    @. V = sT - s
    @. ZV = Z * V

    dotzv = reshape(sum(ZV, dims=1), (n, n))
    normsqz = reshape(sum(abs2, Z, dims=1), (n, n))

    # Fold the distance weight into both n × n coefficient matrices.
    @. normz_pow_γ = 1 / (sqrt(normsqz) + ϵ)^(-γ)
    @. dotzv = dotzv * normz_pow_γ
    @. normsqz = normsqz * normz_pow_γ

    for p in 1:n
        # The final assembly is a contraction over the q index:
        # du[:, p] = V[:, p, :] * normsqz[p, :] - Z[:, p, :] * dotzv[p, :]
        mul!((@view du[:, p]), (@view V[:, p, :]), (@view normsqz[p, :]))
        mul!((@view du[:, p]), (@view Z[:, p, :]), (@view dotzv[p, :]), -1, 1)
    end
    return nothing
end
"""
    matrix_impl2!(du, u, s; γ=-3, Z, V, ZV, normz_pow_γ)

Evaluate the pairwise sum

    du[:, p] = Σ_q (|z|² v - (z ⋅ v) z) / (|z| + ε)^(-γ)

with `z = u[:, p] - u[:, q]`, `v = s[:, q] - s[:, p]` and
`ε = eps(eltype(du))`, obtaining `|z|²` and `z ⋅ v` from the Gram matrices
`u' * u` and `u' * s` instead of reducing over `d × n × n` arrays.
`du` is overwritten.

# Arguments
- `du`: `d × n` output.
- `u`, `s`: `d × n` inputs indexed as `[component, particle]`.

# Keywords
- `γ`: exponent of the distance weight (default `-3`).
- `Z`, `V`: `d × n × n` scratch arrays, fully overwritten. Allocated on demand
  when omitted.
- `ZV`: accepted for signature compatibility with `matrix_impl!`; unused here.
- `normz_pow_γ`: `n × n` scratch array, fully overwritten. Allocated on demand
  when omitted.

Returns `nothing`.
"""
function matrix_impl2!(
    du, u, s;
    γ=-3,
    Z=similar(du, size(u, 1), size(u, 2), size(u, 2)),
    V=similar(Z),
    ZV=nothing,
    normz_pow_γ=similar(du, size(u, 2), size(u, 2)),
)
    d = size(u, 1)
    n = size(u, 2)
    ϵ = eps(eltype(du))

    uT = reshape(u, (d, 1, n))
    sT = reshape(s, (d, 1, n))

    u_dot_u = u' * u
    u_norm_sq = diag(u_dot_u)
    u_dot_s = u' * s
    u_dot_s_diag = diag(u_dot_s)

    # normsqz = |u_p - u_q|^2
    #         = |u_p|^2 + |u_q|^2 - 2*u_p ⋅ u_q
    # Cancellation can leave this a few ulps below zero when u_p ≈ u_q, which
    # would make `sqrt` throw; clamp to the mathematically valid range.
    lower = zero(eltype(u_dot_u))
    normsqz = @. max(u_norm_sq + u_norm_sq' - 2 * u_dot_u, lower)

    # dotzv = (u_p - u_q) ⋅ (s_q - s_p)
    #       = (u_p ⋅ s_q) - (u_q ⋅ s_q) - (u_p ⋅ s_p) + (u_q ⋅ s_p)
    dotzv = @. u_dot_s + u_dot_s' - u_dot_s_diag - u_dot_s_diag'

    # Indices are (i, p, q)
    @. Z = u - uT
    @. V = sT - s

    # Fold the distance weight into both n × n coefficient matrices.
    @. normz_pow_γ = 1 / (sqrt(normsqz) + ϵ)^(-γ)
    @. dotzv = dotzv * normz_pow_γ
    @. normsqz = normsqz * normz_pow_γ

    # The final assembly is a contraction over the q index, keeping p fixed.
    # This can be accomplished with a fast gemv_batched kernel call:
    # https://docs.nvidia.com/cuda/cublas/#cublas-t-gemvbatched
    # https://fluxml.ai/NNlib.jl/dev/reference/#NNlib.batched_mul
    for p in 1:n
        mul!((@view du[:, p]), (@view V[:, p, :]), (@view normsqz[p, :]))
        mul!((@view du[:, p]), (@view Z[:, p, :]), (@view dotzv[p, :]), -1, 1)
    end
    return nothing
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment