
@maleadt
Created June 6, 2019 22:20
Tridiagonal matrix algorithm on the GPU with Julia
# experimentation with batched tridiagonal solvers on the GPU for Oceananigans.jl
#
# - reference serial CPU implementation
# - batched GPU implementation using cuSPARSE (fastest)
# - batched GPU implementation based on the serial CPU implementation (slow but flexible)
# - parallel GPU implementation (potentially fast and flexible)
#
# see `test_batched` and `bench_batched`
using CUDAdrv
using CuArrays
using CUDAnative
using LinearAlgebra
# reference CPU implementation per Numerical Recipes, Press et al. 1992 (sec. 2.4)
function tridag(M::Tridiagonal{T,<:Array}, rhs::Vector{T})::Vector{T} where T
    N = length(rhs)
    phi = similar(rhs)
    gamma = similar(rhs)

    # forward sweep (Thomas algorithm)
    beta = M.d[1]
    phi[1] = rhs[1] / beta
    for j = 2:N
        gamma[j] = M.du[j-1] / beta
        beta = M.d[j] - M.dl[j-1]*gamma[j]
        if abs(beta) < 1.e-12
            # This should only happen on the last element of the forward pass for
            # problems with a zero eigenvalue. In that case the algorithm is still stable.
            break
        end
        phi[j] = (rhs[j] - M.dl[j-1]*phi[j-1]) / beta
    end

    # back substitution
    for j = 1:N-1
        k = N - j
        phi[k] = phi[k] - gamma[k+1]*phi[k+1]
    end

    return phi
end
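# A minimal sanity check of `tridag` against Julia's built-in tridiagonal solve
# (a sketch; `example_tridag` and the size N=8 are not part of the original gist):
function example_tridag(N=8)
    # make the diagonal dominant so the pivot `beta` never gets close to zero
    M = Tridiagonal(rand(N-1), rand(N) .+ 2, rand(N-1))
    rhs = rand(N)
    phi = tridag(M, rhs)
    @assert M * phi ≈ rhs
    @assert phi ≈ M \ rhs
    return phi
end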
# batched GPU implementation running the serial algorithm on each thread
function batched_tridiag!(a::CuVector{T}, b::CuVector{T}, c::CuVector{T}, d::CuVector{T},
                          x::CuVector{T}, batchCount::Integer, batchStride::Integer) where T
    N = length(d) ÷ batchCount
    @assert length(a) == length(b) == length(c) == N*batchCount

    workspace = similar(d)

    function kernel(a, b, c, d, x, workspace)
        # one thread per batched system
        i = (blockIdx().x-1) * blockDim().x + threadIdx().x
        @inbounds if i <= batchCount
            offset = (i-1) * batchStride

            beta = b[1+offset]
            x[1+offset] = d[1+offset] / beta
            for j = 2:N
                k = j+offset
                workspace[k] = c[k-1] / beta
                beta = b[k] - a[k]*workspace[k]
                if abs(beta) < 1.e-12
                    # This should only happen on the last element of the forward pass for
                    # problems with a zero eigenvalue. In that case the algorithm is still stable.
                    break
                end
                x[k] = (d[k] - a[k]*x[k-1]) / beta
            end

            for j = 1:N-1
                k = N-j+offset
                x[k] = x[k] - workspace[k+1]*x[k+1]
            end
        end

        return
    end

    function get_config(kernel)
        fun = kernel.fun
        config = launch_configuration(fun)

        # round up to cover all batches (one thread per batch)
        blocks = (batchCount + config.threads - 1) ÷ config.threads

        return (threads=config.threads, blocks=blocks)
    end

    @cuda config=get_config kernel(a, b, c, d, x, workspace)

    CuArrays.unsafe_free!(workspace)

    return
end

function batched_tridiag(a, b, c, d, batchCount, batchStride)
    x = similar(d)
    batched_tridiag!(a, b, c, d, x, batchCount, batchStride)
    return x
end
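# Layout sketch for the batched interface (following the conventions of `test_batched`
# below, where the data is a column-major (Z, Y, X) array so each system of length Z is
# contiguous in memory): batchCount = X*Y independent systems, batchStride = Z, and
# element k of system i lives at index k + (i-1)*Z of the flat vectors. A call then
# looks like
#
#   flat_x = batched_tridiag(reshape(lower, X*Y*Z), reshape(middle, X*Y*Z),
#                            reshape(upper, X*Y*Z), reshape(rhs, X*Y*Z), X*Y, Z)
#   x = reshape(flat_x, Z, Y, X)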
# parallel tridiagonal solver using cyclic reduction
# FIXME: doesn't work on non-pow2 sized inputs
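#
# Each forward-elimination step combines equation i with its neighbours at distance
# delta to eliminate their unknowns, halving the number of coupled equations:
#
#   tmp1 = a[i] / b[i-delta];  tmp2 = c[i] / b[i+delta]
#   b[i] <- b[i] - c[i-delta]*tmp1 - a[i+delta]*tmp2
#   d[i] <- d[i] - d[i-delta]*tmp1 - d[i+delta]*tmp2
#   a[i] <- -a[i-delta]*tmp1
#   c[i] <- -c[i+delta]*tmp2
#
# After log2(N/2) such steps only a 2x2 system remains, which is solved directly;
# back substitution then recovers the eliminated unknowns in reverse order.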
function batched_parallel_tridiag!(d_a::CuVector{T}, d_b::CuVector{T}, d_c::CuVector{T}, d_d::CuVector{T},
                                   d_x::CuVector{T}, batchCount::Integer, batchStride::Integer) where {T}
    N = length(d_d) ÷ batchCount
    @assert ispow2(N)
    @assert length(d_a) == length(d_b) == length(d_c) == N*batchCount

    function kernel(d_a, d_b, d_c, d_d, d_x)
        thid = threadIdx().x
        blid = blockIdx().x

        numThreads = blockDim().x
        iterations = floor(Int, CUDAnative.log2(Float32(N ÷ 2)))

        @inbounds begin
            # load data into shared memory (each thread loads two elements per array)
            a = @cuDynamicSharedMem(T, (N,))
            b = @cuDynamicSharedMem(T, (N,), N*sizeof(T))
            c = @cuDynamicSharedMem(T, (N,), N*sizeof(T)*2)
            d = @cuDynamicSharedMem(T, (N,), N*sizeof(T)*3)
            x = @cuDynamicSharedMem(T, (N,), N*sizeof(T)*4)

            a[thid] = d_a[thid + (blid-1) * batchStride]
            a[thid + blockDim().x] = d_a[thid + blockDim().x + (blid-1) * batchStride]

            b[thid] = d_b[thid + (blid-1) * batchStride]
            b[thid + blockDim().x] = d_b[thid + blockDim().x + (blid-1) * batchStride]

            c[thid] = d_c[thid + (blid-1) * batchStride]
            c[thid + blockDim().x] = d_c[thid + blockDim().x + (blid-1) * batchStride]

            d[thid] = d_d[thid + (blid-1) * batchStride]
            d[thid + blockDim().x] = d_d[thid + blockDim().x + (blid-1) * batchStride]

            sync_threads()

            # forward elimination
            stride = 1
            for j = 1:iterations
                sync_threads()

                stride *= 2
                delta = stride ÷ 2

                if threadIdx().x <= numThreads
                    i = stride * (threadIdx().x - 1) + stride
                    iLeft = i - delta
                    iRight = i + delta
                    if iRight > N
                        iRight = N
                    end

                    tmp1 = a[i] / b[iLeft]
                    tmp2 = c[i] / b[iRight]
                    b[i] = b[i] - c[iLeft] * tmp1 - a[iRight] * tmp2
                    d[i] = d[i] - d[iLeft] * tmp1 - d[iRight] * tmp2
                    a[i] = -a[iLeft] * tmp1
                    c[i] = -c[iRight] * tmp2
                end

                numThreads ÷= 2
            end

            # solve the remaining 2x2 system directly
            if thid <= 2
                addr1 = stride
                addr2 = 2 * stride

                tmp3 = b[addr2]*b[addr1] - c[addr1]*a[addr2]
                x[addr1] = (b[addr2]*d[addr1] - c[addr1]*d[addr2]) / tmp3
                x[addr2] = (d[addr2]*b[addr1] - d[addr1]*a[addr2]) / tmp3
            end

            # backward substitution
            numThreads = 2
            for j = 1:iterations
                delta = stride ÷ 2
                sync_threads()

                if thid <= numThreads
                    i = stride * (thid - 1) + stride ÷ 2
                    if i == delta
                        x[i] = (d[i] - c[i]*x[i+delta]) / b[i]
                    else
                        x[i] = (d[i] - a[i]*x[i-delta] - c[i]*x[i+delta]) / b[i]
                    end
                end

                stride ÷= 2
                numThreads *= 2
            end

            sync_threads()

            # write back to global memory
            d_x[thid + (blid-1) * batchStride] = x[thid]
            d_x[thid + blockDim().x + (blid-1) * batchStride] = x[thid + blockDim().x]
        end

        return
    end

    threads = N ÷ 2
    shmem = 5 * sizeof(T) * N
    @cuda blocks=batchCount threads=threads shmem=shmem kernel(d_a, d_b, d_c, d_d, d_x)

    return
end

function batched_parallel_tridiag(a, b, c, d, batchCount, batchStride)
    x = similar(d)
    batched_parallel_tridiag!(a, b, c, d, x, batchCount, batchStride)
    return x
end
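# One possible workaround for the non-pow2 limitation above (an untested sketch, not part
# of the original gist): pad each system to the next power of two with decoupled identity
# rows (a = c = d = 0, b = 1). This assumes the convention also used in the tests below,
# i.e. the last upper-diagonal entry of each real system is zero, so the padded rows
# never feed back into the real solution.
function pad_to_pow2(A::CuArray{T,3}, fillval) where T
    Z, Y, X = size(A)
    Zp = nextpow(2, Z)
    P = fill(T(fillval), Zp, Y, X)   # assemble the padded array on the CPU for simplicity
    P[1:Z, :, :] = Array(A)
    return CuArray(P)
end
# Usage sketch: pad lower/upper/rhs with zeros and the diagonal with ones, solve the
# padded systems with batchStride = Zp, and keep only the first Z rows of the result.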
# test and benchmark code
using Test

function test_single()
    # allocate data
    N = 256
    t = Tridiagonal(rand(N,N))
    rhs = rand(N)

    # solve
    phi = tridag(t, rhs)
    @test t * phi ≈ rhs
end

using CuArrays
using CuArrays.CUSPARSE

function test_batched()
    # problem definition (weird sizes to flush out issues)
    X = 7
    Y = 13
    Z = 64
    T = Float64

    # allocate data
    lower = CuArrays.rand(T, Z, Y, X)
    lower[1, :, :] .= 0
    upper = CuArrays.rand(T, Z, Y, X)
    upper[Z, :, :] .= 0
    middle = CuArrays.rand(T, Z, Y, X)
    rhs = CuArrays.rand(T, Z, Y, X)

    # batched interface uses 1d vectors
    flat_upper = reshape(upper, X*Y*Z)
    flat_middle = reshape(middle, X*Y*Z)
    flat_lower = reshape(lower, X*Y*Z)
    flat_rhs = reshape(rhs, X*Y*Z)

    for f in (CUSPARSE.gtsvStridedBatch, batched_tridiag, batched_parallel_tridiag)
        flat_out = f(flat_lower, flat_middle, flat_upper, flat_rhs, X*Y, Z)
        out = reshape(flat_out, Z, Y, X)

        # verify (note: the `break` below means only the (x=1, y=1) column is checked)
        for x in 1:X, y in 1:Y
            a = Array(lower[:, y, x])
            b = Array(middle[:, y, x])
            c = Array(upper[:, y, x])
            u = Array(out[:, y, x])
            d = Array(rhs[:, y, x])

            t = Tridiagonal(a[2:end], b, c[1:end-1])
            @test t * u ≈ d

            break
        end
    end

    return
end
using BenchmarkTools
using Statistics

function bench_batched()
    # problem definition
    X = 256
    Y = 256
    Z = 256
    T = Float32

    # allocate data
    lower = CuArrays.rand(T, Z, Y, X)
    lower[1, :, :] .= 0
    upper = CuArrays.rand(T, Z, Y, X)
    upper[Z, :, :] .= 0
    middle = CuArrays.rand(T, Z, Y, X)
    rhs = CuArrays.rand(T, Z, Y, X)

    # batched interface uses 1d vectors
    flat_upper = reshape(upper, X*Y*Z)
    flat_middle = reshape(middle, X*Y*Z)
    flat_lower = reshape(lower, X*Y*Z)
    flat_rhs = reshape(rhs, X*Y*Z)

    suite = BenchmarkGroup()

    suite["cuSPARSE"] = @benchmarkable begin
        CuArrays.@sync CUSPARSE.gtsvStridedBatch($flat_lower, $flat_middle, $flat_upper, $flat_rhs, $X*$Y, $Z)
    end setup=(GC.gc())

    suite["serial"] = @benchmarkable begin
        CuArrays.@sync batched_tridiag($flat_lower, $flat_middle, $flat_upper, $flat_rhs, $X*$Y, $Z)
    end setup=(GC.gc())

    suite["parallel"] = @benchmarkable begin
        CuArrays.@sync batched_parallel_tridiag($flat_lower, $flat_middle, $flat_upper, $flat_rhs, $X*$Y, $Z)
    end setup=(GC.gc())

    warmup(suite)
    @show results = run(suite)
    judge(median(results["parallel"]), median(results["cuSPARSE"]))
end
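# Example usage (a sketch, not part of the original gist; requires a CUDA-capable GPU):
#
#   test_single()
#   test_batched()
#   results = bench_batched()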
@ali-ramadhan

Just posting your benchmarks from the GPU hackathon for reference:

[image: benchmark results from the GPU hackathon]

@rveltz

rveltz commented Dec 14, 2020

Hi,

I am wondering if you have a version of the PCR that is threaded (CPU) rather than GPU?

Thank you

@ali-ramadhan

@rveltz Not sure if such a version exists but it should be possible to write one using KernelAbstractions.jl: you write one kernel that runs multi-threaded on a CPU and also works on CUDA GPUs.
