Skip to content

Instantly share code, notes, and snippets.

@oxinabox
Last active May 23, 2021 12:16
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save oxinabox/ad054535b8f84cf060d3ac35af77c64f to your computer and use it in GitHub Desktop.
Save oxinabox/ad054535b8f84cf060d3ac35af77c64f to your computer and use it in GitHub Desktop.
Why are so many implementations of the Thomas algorithm wrong?
using LinearAlgebra
# Wikipedia non-preserving version (transcribed from VB)
# https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm
# this one is wrong (found this in use in the wild 😒)
function thomas_algorithm!(a, b, c, r, ::Val{1})
n = length(b)
for i in 2:(n-1)
m = a[i]/b[i-1];
b[i] = b[i] - m * c[i - 1];
r[i] = r[i] - m*r[i-1];
end
x = similar(b)
x[end] = r[end]/b[end];
for i in (n-1):-1:1
x[i] = (r[i] - c[i] * x[i+1]) / b[i]
end
return x
end
# wikipdia out of place version
# this one is right
function thomas_algorithm!(a, b, c, d, ::Val{2})
n = length(b)
dp = similar(d)
cp = similar(c)
dp[1] = d[1]/b[1]
cp[1] = c[1]/b[1]
for i in 2:n
r = 1/(b[i] - a[i]*cp[i-1])
dp[i] = r*(d[i] - a[i]*dp[i-1])
cp[i] = r * c[i]
end
x = similar(d)
x[end] = dp[end]
for i in (n-1):-1:1
x[i] = dp[i] - cp[i]*x[i+1]
end
return x
end
# Algorithm 1 from https://people.maths.ox.ac.uk/gilesm/files/toms_16b.pdf
# this one is also wrong
function thomas_algorithm!(a, b, c, d, ::Val{3})
n = length(b)
dp = similar(d)
cp = similar(c)
dp[1] = d[1]/b[1]
cp[1] = c[1]/b[1]
for i in 2:n
r = 1/(b[i] - a[i]*c[i-1])
dp[i] = r*(d[i] - a[i]*d[i-1])
cp[i] = r * c[i]
end
for i in (n-1):-1:1
d[i] = dp[i] - cp[i]*d[i+1]
end
return d
end
# from torchcde
# https://github.com/patrick-kidger/torchcde/blob/d3ebdd554f138a07832e31cacca7bc0944d2004e/torchcde/misc.py#L13
# Correct as long as not padded
function thomas_algorithm!(A_lower, A_diagonal, A_upper, b, ::Val{4})
channels = length(A_diagonal)
new_b = similar(b)
new_A_diagonal = similar(A_diagonal)
outs = similar(A_diagonal)
new_b[1] = b[1]
new_A_diagonal[1] = A_diagonal[1]
for i in 2:channels
w = A_lower[i-1]/new_A_diagonal[i-1];
new_A_diagonal[i] = A_diagonal[i] - w * A_upper[i - 1];
new_b[i] = b[i] - w*new_b[i-1];
end
outs[end] = new_b[end]/new_A_diagonal[end];
for i in (channels-1):-1:1
outs[i] = (new_b[i] - A_upper[i] * outs[i+1]) / new_A_diagonal[i]
end
return outs
end
#################
# No padding
function thomas_algorithm(lhs::Tridiagonal, r, ver::Val{4})
a = diag(lhs, -1)
b = diag(lhs)
c = diag(lhs, 1)
d = copy(r)
return thomas_algorithm!(a, b, c, d, ver)
end
# padding
function thomas_algorithm(lhs::Tridiagonal, r, ver::Union{Val{1}, Val{2}, Val{3}})
a = [0; diag(lhs, -1)]
b = diag(lhs)
c = [diag(lhs, 1); 0]
d = copy(r)
return thomas_algorithm!(a, b, c, d, ver)
end
######################
# Experiment 2x2
lhs = Tridiagonal([2.0 1.0; 2.0 7.0])
rhs = [1.800000007527517, -7.400000108059436]
@show lhs\rhs
@show thomas_algorithm(lhs, rhs, Val(1))
@show thomas_algorithm(lhs, rhs, Val(2))
@show thomas_algorithm(lhs, rhs, Val(3))
@show thomas_algorithm(lhs, rhs, Val(4))
@show lhs*(lhs\rhs) β‰ˆ rhs # true
@show lhs*(thomas_algorithm(lhs, rhs, Val(1))) β‰ˆ rhs # false
#@show lhs*(thomas_algorithm(lhs, rhs, Val(2))) β‰ˆ rhs # true
#@show lhs*(thomas_algorithm(lhs, rhs, Val(3))) β‰ˆ rhs # false
@show lhs*(thomas_algorithm(lhs, rhs, Val(4))) β‰ˆ rhs
#####################
# Experiment 3x3
lhs = Tridiagonal([1., 2], [10., 20, 30], [1., 2])
rhs = [11., 12, 13]
lhs\rhs
lhs*(lhs\rhs)
@show lhs\rhs
@show thomas_algorithm(lhs, rhs, Val(1))
@show thomas_algorithm(lhs, rhs, Val(2))
@show thomas_algorithm(lhs, rhs, Val(3))
@show thomas_algorithm(lhs, rhs, Val(4))
@show lhs*(lhs\rhs) β‰ˆ rhs # true
@show lhs*(thomas_algorithm(lhs, rhs, Val(1))) β‰ˆ rhs # false
@show lhs*(thomas_algorithm(lhs, rhs, Val(2))) β‰ˆ rhs # true
@show lhs*(thomas_algorithm(lhs, rhs, Val(3))) β‰ˆ rhs # false
@show lhs*(thomas_algorithm(lhs, rhs, Val(4))) β‰ˆ rhs # true
@mschauer
Copy link

Did you fix the Wikipedia ones?

@oxinabox
Copy link
Author

oxinabox commented May 22, 2021

Not yet. I am waiting to be told that I implemented them wrong.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment