Created
June 6, 2019 22:20
-
-
Save maleadt/1ec91b3b12ede9898958c95596cabe8b to your computer and use it in GitHub Desktop.
Tridiagonal matrix algorithm on the GPU with Julia
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# experimentation with batched tridiagonal solvers on the GPU for Oceananigans.jl | |
# | |
# - reference serial CPU implementation | |
# - batched GPU implementation using cuSPARSE (fastest) | |
# - batched GPU implementation based on the serial CPU implementation (slow but flexible) | |
# - parallel GPU implementation (potentially fast and flexible) | |
# | |
# see `test_batched` and `bench_batched` | |
using CUDAdrv | |
using CuArrays | |
using CUDAnative | |
using LinearAlgebra | |
"""
    tridag(M::Tridiagonal, rhs::Vector; tol=1e-12) -> Vector

Solve `M * phi = rhs` with the Thomas algorithm: the reference CPU implementation per
Numerical Recipes, Press et al. 1992 (sec. 2.4). No pivoting is performed.

If the magnitude of a pivot drops below `tol`, forward elimination stops early. This
should only happen on the last element of the forward pass for problems with a zero
eigenvalue, in which case the algorithm is still stable. The unsolved tail of `phi`
is then left at zero (not uninitialized memory), so the result is deterministic.
An empty `rhs` returns an empty vector.
"""
function tridag(M::Tridiagonal{T,<:Array}, rhs::Vector{T}; tol=1e-12)::Vector{T} where T
    N = length(rhs)
    # zero-initialized: an early break must not leave undefined values that the
    # back substitution below would read
    phi = zero(rhs)
    gamma = zero(rhs)
    N == 0 && return phi
    beta = M.d[1]
    phi[1] = rhs[1] / beta
    # forward elimination; gamma holds the modified upper diagonal
    for j in 2:N
        gamma[j] = M.du[j-1] / beta
        beta = M.d[j] - M.dl[j-1] * gamma[j]
        if abs(beta) < tol
            # This should only happen on the last element of the forward pass for
            # problems with a zero eigenvalue. In that case the algorithm is still stable.
            break
        end
        phi[j] = (rhs[j] - M.dl[j-1] * phi[j-1]) / beta
    end
    # back substitution, last row to first
    for k in N-1:-1:1
        phi[k] -= gamma[k+1] * phi[k+1]
    end
    return phi
end
# Batched GPU implementation running the serial Thomas algorithm, one thread per system.
#
# - `a`, `b`, `c`: lower, main and upper diagonals, laid out like `d`; for each system
#   the kernel reads `a` at rows 2:N and `c` at rows 1:N-1 only
# - `d`: right-hand sides; `x`: output, overwritten with the solutions
# - `batchCount`: number of systems; `batchStride`: element offset between systems
# The system size is `N = length(d) ÷ batchCount`.
function batched_tridiag!(a::CuVector{T}, b::CuVector{T}, c::CuVector{T}, d::CuVector{T},
                          x::CuVector{T}, batchCount::Integer, batchStride::Integer) where T
    N = length(d) ÷ batchCount
    @assert length(a) == length(b) == length(c) == N*batchCount
    # every system must lie inside the arrays, since the kernel indexes with @inbounds
    @assert (batchCount - 1) * batchStride + N <= length(d)
    workspace = similar(d)
    function kernel(a, b, c, d, x, workspace)
        i = (blockIdx().x-1) * blockDim().x + threadIdx().x
        @inbounds if i <= batchCount
            offset = (i-1) * batchStride
            # forward elimination; workspace holds the modified upper diagonal
            beta = b[1+offset]
            x[1+offset] = d[1+offset] / beta
            for j = 2:N
                k = j+offset
                workspace[k] = c[k-1] / beta
                beta = b[k] - a[k]*workspace[k]
                if abs(beta) < 1.e-12
                    # This should only happen on the last element of the forward pass
                    # for problems with a zero eigenvalue. In that case the algorithm
                    # is still stable. `x` and `workspace` are not initialized here,
                    # so zero the unsolved tail before the back substitution reads it.
                    for m = k:(N+offset)
                        x[m] = zero(beta)
                    end
                    for m = (k+1):(N+offset)
                        workspace[m] = zero(beta)
                    end
                    break
                end
                x[k] = (d[k]-a[k]*x[k-1])/beta
            end
            # back substitution, last row to first
            for j = 1:N-1
                k = N-j+offset
                x[k] = x[k]-workspace[k+1]*x[k+1]
            end
        end
        return
    end
    function get_config(kernel)
        fun = kernel.fun
        config = launch_configuration(fun)
        # one thread per system (see the `i <= batchCount` guard): round up to
        # cover all batches
        blocks = cld(batchCount, config.threads)
        return (threads=config.threads, blocks=blocks)
    end
    @cuda config=get_config kernel(a, b, c, d, x, workspace)
    CuArrays.unsafe_free!(workspace)
    return
end
# Out-of-place variant of `batched_tridiag!`: allocate the solution with `similar(d)`,
# solve every system of the batch into it, and return it.
batched_tridiag(a, b, c, d, batchCount, batchStride) =
    let solution = similar(d)
        batched_tridiag!(a, b, c, d, solution, batchCount, batchStride)
        solution
    end
# Parallel batched tridiagonal solver using cyclic reduction. One thread block solves
# one system with N ÷ 2 threads (each thread loads and stores two rows), and the whole
# system is staged in dynamic shared memory (5 * N * sizeof(T) bytes).
#
# Arguments mirror `batched_tridiag!`. Like the serial kernel, the first lower-diagonal
# entry and the last upper-diagonal entry of each system do not affect the result.
#
# FIXME: doesn't work on non-pow2 sized inputs
function batched_parallel_tridiag!(d_a::CuVector{T}, d_b::CuVector{T}, d_c::CuVector{T}, d_d::CuVector{T},
                                   d_x::CuVector{T}, batchCount::Integer, batchStride::Integer) where {T}
    N = length(d_d) ÷ batchCount
    # N >= 2 so that at least one thread is launched per block
    @assert ispow2(N) && N >= 2
    @assert length(d_a) == length(d_b) == length(d_c) == N*batchCount
    # every system must lie inside the arrays, since the kernel indexes with @inbounds
    @assert (batchCount - 1) * batchStride + N <= length(d_d)
    # reduction levels until two unknowns remain, i.e. log2(N ÷ 2); exact for powers of two
    iterations = trailing_zeros(N) - 1
    function kernel(d_a, d_b, d_c, d_d, d_x)
        thid = threadIdx().x
        blid = blockIdx().x
        numThreads = blockDim().x
        base = (blid-1) * batchStride
        hi = thid + blockDim().x
        @inbounds begin
            # load data into shared memory; thread `thid` owns rows `thid` and `hi`
            a = @cuDynamicSharedMem(T, (N,))
            b = @cuDynamicSharedMem(T, (N,), N*sizeof(T))
            c = @cuDynamicSharedMem(T, (N,), N*sizeof(T)*2)
            d = @cuDynamicSharedMem(T, (N,), N*sizeof(T)*3)
            x = @cuDynamicSharedMem(T, (N,), N*sizeof(T)*4)
            # a[1] and c[N] lie outside the system but the reduction reads them
            # (iLeft == 1 and the clamped iRight == N), so force them to zero
            a[thid] = ifelse(thid == 1, zero(T), d_a[thid + base])
            a[hi] = d_a[hi + base]
            b[thid] = d_b[thid + base]
            b[hi] = d_b[hi + base]
            c[thid] = d_c[thid + base]
            c[hi] = ifelse(hi == N, zero(T), d_c[hi + base])
            d[thid] = d_d[thid + base]
            d[hi] = d_d[hi + base]
            sync_threads()
            # forward elimination
            stride = 1
            for j = 1:iterations
                sync_threads()
                stride *= 2
                delta = stride ÷ 2
                if thid <= numThreads
                    i = stride * (thid - 1) + stride
                    iLeft = i - delta
                    iRight = i + delta
                    if iRight > N
                        iRight = N
                    end
                    tmp1 = a[i] / b[iLeft]
                    tmp2 = c[i] / b[iRight]
                    b[i] = b[i] - c[iLeft] * tmp1 - a[iRight] * tmp2
                    d[i] = d[i] - d[iLeft] * tmp1 - d[iRight] * tmp2
                    a[i] = -a[iLeft] * tmp1
                    c[i] = -c[iRight] * tmp2
                end
                numThreads ÷= 2
            end
            # the last reduction level is written by two threads; make all of it
            # visible before thread 1 reads rows `stride` and `2 * stride`
            sync_threads()
            # solve the remaining 2x2 system (one thread suffices)
            if thid == 1
                addr1 = stride
                addr2 = 2 * stride
                tmp3 = b[addr2]*b[addr1] - c[addr1]*a[addr2]
                x[addr1] = (b[addr2]*d[addr1]-c[addr1]*d[addr2])/tmp3
                x[addr2] = (d[addr2]*b[addr1]-d[addr1]*a[addr2])/tmp3
            end
            # backward substitution
            numThreads = 2
            for j = 1:iterations
                delta = stride ÷ 2
                sync_threads()
                if thid <= numThreads
                    i = stride * (thid - 1) + stride ÷ 2
                    if i == delta
                        x[i] = (d[i] - c[i]*x[i+delta]) / b[i]
                    else
                        x[i] = (d[i] - a[i]*x[i-delta] - c[i]*x[i+delta]) / b[i]
                    end
                end
                stride ÷= 2
                numThreads *= 2
            end
            sync_threads()
            # write back to global memory
            d_x[thid + base] = x[thid]
            d_x[hi + base] = x[hi]
        end
        return
    end
    threads = N ÷ 2
    shmem = 5 * sizeof(T) * N
    @cuda blocks=batchCount threads=threads shmem=shmem kernel(d_a, d_b, d_c, d_d, d_x)
    return
end
# Out-of-place variant of `batched_parallel_tridiag!`: allocate the solution with
# `similar(d)`, solve every system of the batch into it, and return it.
batched_parallel_tridiag(a, b, c, d, batchCount, batchStride) =
    let solution = similar(d)
        batched_parallel_tridiag!(a, b, c, d, solution, batchCount, batchStride)
        solution
    end
# test and benchmark code | |
using Test | |
# Check the CPU reference solver `tridag` on one random system.
# The Thomas algorithm does no pivoting; it is guaranteed stable for strictly diagonally
# dominant matrices, whereas `Tridiagonal(rand(N, N))` can produce tiny pivots and
# make the residual check fail intermittently.
function test_single()
    # allocate data: off-diagonals in [0, 1), diagonal in [2, 3)
    N = 256
    t = Tridiagonal(rand(N-1), rand(N) .+ 2, rand(N-1))
    rhs = rand(N)
    # solve
    phi = tridag(t, rhs)
    @test t * phi ≈ rhs
end
using CuArrays | |
using CuArrays.CUSPARSE | |
# Solve a batch of random tridiagonal systems with cuSPARSE and both hand-written GPU
# solvers, then check the residual of every system on the host.
function test_batched()
    # problem definition (weird sizes to flush out issues)
    X = 7
    Y = 13
    Z = 64
    T = Float64
    # allocate data; the first lower-diagonal and the last upper-diagonal entry of each
    # system lie outside the matrix and are set to zero (broadcast assignment into a slice)
    lower = CuArrays.rand(T, Z, Y, X)
    lower[1, :, :] .= 0
    upper = CuArrays.rand(T, Z, Y, X)
    upper[Z, :, :] .= 0
    # diagonal in [2, 3): every system is strictly diagonally dominant, which keeps
    # the pivot-free solvers stable on random data
    middle = CuArrays.rand(T, Z, Y, X) .+ 2
    rhs = CuArrays.rand(T, Z, Y, X)
    # batched interface uses 1d vectors
    flat_upper = reshape(upper, X*Y*Z)
    flat_middle = reshape(middle, X*Y*Z)
    flat_lower = reshape(lower, X*Y*Z)
    flat_rhs = reshape(rhs, X*Y*Z)
    # host copies for verification, transferred once rather than once per system
    h_lower = Array(lower)
    h_middle = Array(middle)
    h_upper = Array(upper)
    h_rhs = Array(rhs)
    for f in (CUSPARSE.gtsvStridedBatch, batched_tridiag, batched_parallel_tridiag)
        flat_out = f(flat_lower, flat_middle, flat_upper, flat_rhs, X*Y, Z)
        out = Array(reshape(flat_out, Z, Y, X))
        # verify every system of the batch
        for x in 1:X, y in 1:Y
            a = h_lower[:, y, x]
            b = h_middle[:, y, x]
            c = h_upper[:, y, x]
            u = out[:, y, x]
            d = h_rhs[:, y, x]
            t = Tridiagonal(a[2:end], b, c[1:end-1])
            @test t * u ≈ d
        end
    end
    return
end
using BenchmarkTools | |
using Statistics | |
# Benchmark the cuSPARSE, serial-per-thread and cyclic-reduction batched solvers on a
# 256^3 problem and judge the parallel solver's median time against cuSPARSE.
function bench_batched()
    # problem definition
    X = 256
    Y = 256
    Z = 256
    T = Float32
    # allocate data; the first lower-diagonal and the last upper-diagonal entry of each
    # system lie outside the matrix and are set to zero (broadcast assignment into a slice)
    lower = CuArrays.rand(T, Z, Y, X)
    lower[1, :, :] .= 0
    upper = CuArrays.rand(T, Z, Y, X)
    upper[Z, :, :] .= 0
    middle = CuArrays.rand(T, Z, Y, X)
    rhs = CuArrays.rand(T, Z, Y, X)
    # batched interface uses 1d vectors
    flat_upper = reshape(upper, X*Y*Z)
    flat_middle = reshape(middle, X*Y*Z)
    flat_lower = reshape(lower, X*Y*Z)
    flat_rhs = reshape(rhs, X*Y*Z)
    suite = BenchmarkGroup()
    suite["cuSPARSE"] = @benchmarkable begin
        CuArrays.@sync CUSPARSE.gtsvStridedBatch($flat_lower, $flat_middle, $flat_upper, $flat_rhs, $X*$Y, $Z)
    end setup=(GC.gc())
    suite["serial"] = @benchmarkable begin
        CuArrays.@sync batched_tridiag($flat_lower, $flat_middle, $flat_upper, $flat_rhs, $X*$Y, $Z)
    end setup=(GC.gc())
    suite["parallel"] = @benchmarkable begin
        CuArrays.@sync batched_parallel_tridiag($flat_lower, $flat_middle, $flat_upper, $flat_rhs, $X*$Y, $Z)
    end setup=(GC.gc())
    warmup(suite)
    @show results = run(suite)
    return judge(median(results["parallel"]), median(results["cuSPARSE"]))
end
Hi,
I am wondering if you have a version of the PCR that is threaded (CPU) rather than GPU?
Thank you
@rveltz Not sure if such a version exists but it should be possible to write one using KernelAbstractions.jl: you write one kernel that runs multi-threaded on a CPU and also works on CUDA GPUs.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Just posting your benchmarks from the GPU hackathon for reference: