Skip to content

Instantly share code, notes, and snippets.

@theogf
Created November 5, 2021 10:10
Show Gist options
  • Save theogf/50426b2e991bba8868f6728d1325518b to your computer and use it in GitHub Desktop.
Save theogf/50426b2e991bba8868f6728d1325518b to your computer and use it in GitHub Desktop.
Test Tullio with Kernel Functions
using Tullio
using Distances
using LinearAlgebra
using BenchmarkTools
using CUDA, CUDAKernels, KernelAbstractions
using Functors
using KernelFunctions
using Test
using Functors
"""
    DotProduct()

Singleton "metric" whose call returns the inner product `dot(x, y)` (which conjugates
`x` for complex inputs). It is not a true distance; it is used alongside `SqEuclidean`
to compare the pairwise implementations in this script.
"""
struct DotProduct end

function (::DotProduct)(x, y)
    return dot(x, y)
end
# NOTE(review): this defines the call method of Distances' `SqEuclidean` from outside
# Distances (type piracy: neither the function nor the type is owned here), so every
# `SqEuclidean()(x, y)` call in this session uses this expanded form.
# Uses ‖x - y‖² = ‖x‖² - 2⟨x, y⟩ + ‖y‖². The subtraction may cancel, so nearly equal
# inputs can produce a tiny negative result.
function (::SqEuclidean)(x, y)
    xsq = sum(abs2, x)
    cross = 2 * dot(x, y)
    ysq = sum(abs2, y)
    return xsq - cross + ysq
end
# Reference pairwise evaluation for generic vectors: entry (i, j) is metric(x[i], y[j]).
function std_pairwise(metric, x::AbstractVector, y::AbstractVector)
    # Reshape y into a 1×length(y) row so broadcasting against the column x spans all pairs.
    yrow = reshape(y, 1, :)
    return broadcast(metric, x, yrow)
end
# Squared Euclidean distances between the columns of x.X and y.X, using
# ‖xᵢ - yⱼ‖² = ‖xᵢ‖² + ‖yⱼ‖² - 2⟨xᵢ, yⱼ⟩ so that the cross term is a single matrix product.
# Entry (i, j) pairs column i of x.X with column j of y.X, matching x.X' * y.X. The x norms
# must therefore be a column (size(x.X, 2) × 1) and the y norms a row (1 × size(y.X, 2)).
# With the orientations swapped, the result is wrong for square inputs and throws a
# DimensionMismatch otherwise.
function std_pairwise(::SqEuclidean, x::ColVecs, y::ColVecs)
    xnorms = sum(abs2, x.X; dims=1)'   # size(x.X, 2) × 1
    ynorms = sum(abs2, y.X; dims=1)    # 1 × size(y.X, 2)
    return xnorms .+ ynorms .- 2 .* (x.X' * y.X)
end
# Inner products between every pair of columns: entry (i, j) is x.X[:, i]' * y.X[:, j],
# computed as one matrix product.
function std_pairwise(::DotProduct, x::ColVecs, y::ColVecs)
    return x.X' * y.X
end
# `DotProduct` is not a proper metric, so rather than relying on a Distances
# implementation, `pairwise` for it is routed to the reference `std_pairwise`.
function Distances.pairwise(m::DotProduct, x, y)
    return std_pairwise(m, x, y)
end
# Generic Tullio path: allocates K with K[i, j] = metric(x[i], y[j]) for all index pairs.
function tullio_pairwise(metric, x::AbstractVector, y::AbstractVector)
    @tullio K[i, j] := metric(x[i], y[j])
    return K
end
# Squared Euclidean distances via Tullio: K[i, j] = Σₖ (x.X[k, i] - y.X[k, j])²
# (k appears only on the right-hand side, so Tullio sums over it).
# Squaring the difference directly, rather than expanding to x² - 2xy + y², avoids
# cancellation, keeps every summand non-negative, and needs fewer operations per term.
function tullio_pairwise(::SqEuclidean, x::ColVecs, y::ColVecs)
    @tullio K[i, j] := (x.X[k, i] - y.X[k, j])^2
    return K
end
# Column-pair inner products via Tullio: K[i, j] = Σₖ x.X[k, i] * y.X[k, j].
# NOTE(review): unlike `dot` / `x.X'`, no conjugation is applied, so this only agrees
# with the other DotProduct paths for real data.
function tullio_pairwise(::DotProduct, x::ColVecs, y::ColVecs)
    @tullio K[i, j] := x.X[k, i] * y.X[k, j]
    return K
end
# Problem size: D-dimensional points, N points per input set.
D = 20
N = 1000

# Build every representation of one input set from a single random matrix:
# the raw matrix, its ColVecs wrapper, a vector of column copies, and GPU (Float32, via
# `cu`) versions of each. All representations therefore hold the same points, so the
# GPU ColVecs and GPU vector inputs agree with each other and with the CPU data.
function make_inputs(dim, n)
    mat = rand(dim, n)
    vecs = collect.(eachcol(mat))
    gpumat = cu(mat)
    return mat, ColVecs(mat), vecs, gpumat, ColVecs(gpumat), cu.(vecs)
end

Xmat, Xcol, Xvec, gpuX, gpuXcol, gpuXvec = make_inputs(D, N)
Ymat, Ycol, Yvec, gpuY, gpuYcol, gpuYvec = make_inputs(D, N)
# CPU correctness: for each metric and each input representation (ColVecs and vector of
# vectors), Distances' `pairwise` must approximately match both hand-written versions.
@testset "Test correct implementation" begin
    cpu_inputs = ((Xcol, Ycol), (Xvec, Yvec))
    for metric in (SqEuclidean(), DotProduct())
        @testset "$metric" begin
            for (A, B) in cpu_inputs
                @test pairwise(metric, A, B) ≈ std_pairwise(metric, A, B)
                @test pairwise(metric, A, B) ≈ tullio_pairwise(metric, A, B)
            end
        end
    end
end
# GPU smoke test: with scalar indexing of GPU arrays disallowed, the ColVecs
# implementations must run on the GPU inputs without writing warnings. Results are not
# compared against a reference here.
CUDA.allowscalar(false)
@testset "GPU" begin
    gpu_inputs = ((gpuXcol, gpuYcol),)
    for metric in (SqEuclidean(), DotProduct())
        @testset "$metric" begin
            for (A, B) in gpu_inputs
                @test_nowarn std_pairwise(metric, A, B)
                @test_nowarn tullio_pairwise(metric, A, B)
                # NOTE(review): `pairwise(metric, A, B)` is intentionally not exercised;
                # presumably the Distances path does not support these GPU inputs
                # with scalar indexing disabled. Confirm before enabling.
            end
        end
    end
end
# Benchmark every implementation on every (metric, input representation) combination.
# Layout: results_benchmark[metric][type][method] is a BenchmarkTools trial, where method
# :b_dist, :b_std and :b_tullio are Distances `pairwise`, `std_pairwise` and
# `tullio_pairwise` respectively.
# Inputs are looked up by representation directly instead of `eval`-ing a built-up
# variable name such as "Xcol".
benchmark_inputs = Dict(:col => (Xcol, Ycol), :vec => (Xvec, Yvec))
results_benchmark = Dict()
for metric in [SqEuclidean(), DotProduct()]
    results_benchmark[metric] = Dict()
    for type in [:col, :vec]
        trials = results_benchmark[metric][type] = Dict()
        X, Y = benchmark_inputs[type]
        # `$` interpolation makes the benchmarks use the concrete values, not untyped globals.
        trials[:b_dist] = @benchmark pairwise($metric, $X, $Y)
        trials[:b_std] = @benchmark std_pairwise($metric, $X, $Y)
        trials[:b_tullio] = @benchmark tullio_pairwise($metric, $X, $Y)
    end
end
# Pretty-print every collected benchmark trial, grouped by metric, then by input
# representation, then by implementation.
for metric in [SqEuclidean(), DotProduct()]
    println("Testing metric $(metric)")
    for type in [:col, :vec]
        println("Testing X$type")
        for method in [:b_dist, :b_std, :b_tullio]
            println("Method $(method)")
            # Index by the current `metric` so each section shows its own trials.
            display(results_benchmark[metric][type][method])
        end
        println()
    end
    println()
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment