Skip to content

Instantly share code, notes, and snippets.

@ahwillia
Last active August 29, 2019 18:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ahwillia/97682d2600fa58cb758b090d1b328a4d to your computer and use it in GitHub Desktop.
Save ahwillia/97682d2600fa58cb758b090d1b328a4d to your computer and use it in GitHub Desktop.
Fast solver for a symmetric tridiagonal circulant linear system in Python.
import numpy as np
from scipy.linalg import solve_circulant, circulant
from numpy.testing import assert_array_almost_equal
import numba
@numba.jit(nopython=True, cache=True)
def rojo_method(c, a, f, x, z):
    """
    Solves symmetric, tridiagonal circulant system, assuming diagonal
    dominance. Algorithm and notation described in Rojo (1990).

    The system solved is ``a*x[i-1] + c*x[i] + a*x[i+1] = f[i]`` for
    ``i = 0, ..., N-1``, with indices taken modulo ``N``.

    Parameters
    ----------
    c : float
        Diagonal elements.
    a : float
        Off-diagonal elements. Must satisfy abs(c) > 2 * abs(a).
    f : ndarray
        Right-hand side, length N.
    x : ndarray
        Float vector of length N; overwritten with the solution.
    z : ndarray
        Float vector of length N used for intermediate computations;
        overwritten.

    Returns
    -------
    x : ndarray
        The array passed as ``x``, now holding the solution.

    Raises
    ------
    ValueError
        If ``x`` or ``z`` does not have the same size as ``f``, if
        abs(c) > 2 * abs(a) does not hold (including NaN inputs), or if
        N < 3 while ``a != 0``.

    Reference
    ---------
    Rojo O (1990). A new method for solving symmetric circulant
    tridiagonal systems of linear equations. Computers Math Applic.
    20(12):61-67.
    """
    N = f.size
    # The algorithm addresses x and z from both ends (x[-1], z[-2], ...),
    # so they must match f exactly.
    if x.size != N or z.size != N:
        raise ValueError("x and z must have the same size as f")
    # |c| > 2|a|  <=>  |lam| > 2, which makes mu below real with |mu| > 1.
    # Written with `not` so that NaN inputs are rejected as well.
    if not abs(c) > 2 * abs(a):
        raise ValueError("system must satisfy abs(c) > 2 * abs(a)")

    # Zero off-diagonal: the matrix is simply c * I.
    if a == 0:
        for i in range(N):
            x[i] = f[i] / c
        return x

    # The recurrences below touch z[-3], which needs at least 3 unknowns.
    if N < 3:
        raise ValueError("rojo_method requires N >= 3 when a != 0")

    # Divide every row by -a:  -x[i-1] + lam * x[i] - x[i+1] = z[i].
    for i in range(N):
        z[i] = -f[i] / a
    lam = -c / a

    # mu is the root of mu + 1/mu = lam with |mu| > 1.
    if lam > 0:
        mu = 0.5 * lam + np.sqrt(0.25 * (lam ** 2) - 1)
    else:
        mu = 0.5 * lam - np.sqrt(0.25 * (lam ** 2) - 1)

    # Forward recurrence over z[0], ..., z[N-2].
    z[0] = z[0] + (z[-1] / lam)
    for i in range(1, N - 2):
        z[i] = z[i] + (z[i - 1] / mu)
    z[-2] = z[-2] + (z[-1] / lam) + (z[-3] / mu)

    # Backward recurrence from z[N-2] down to z[0].
    z[-2] = z[-2] / mu
    for i in range(N - 2):
        z[-3 - i] = (z[-3 - i] + z[-2 - i]) / mu

    musm1 = (mu ** 2) - 1
    d = (1 - (mu ** -N)) * musm1 * mu
    z0 = z[0]
    zm2 = z[N - 2]

    # Correction for the first N-1 unknowns:
    #   x[i] = z[i] + (musm1 * mu**(i+1-N) * z0
    #                  + (mu**(1-i) + mu**(i+3-N)) * zm2) / d
    # Each power of mu is started at the index where it is largest and
    # divided by mu toward the index where it is smallest, so underflow
    # only affects genuinely negligible terms. Starting from mu**(1-N) and
    # multiplying upward would lose the (non-negligible) terms near
    # i = N-2 whenever |mu|**(N-1) exceeds the double-precision range.
    p = mu  # mu ** (1 - i), for i ascending from 0
    for i in range(N - 1):
        x[i] = z[i] + p * zm2 / d
        p /= mu
    q1 = 1.0 / mu  # mu ** (i + 1 - N), for i descending from N - 2
    q3 = mu        # mu ** (i + 3 - N), for i descending from N - 2
    for i in range(N - 2, -1, -1):
        x[i] += (musm1 * q1 * z0 + q3 * zm2) / d
        q1 /= mu
        q3 /= mu

    # Last row of the normalized system, solved for x[N-1].
    x[-1] = (z[-1] + x[0] + x[-2]) / lam
    return x
# Simple test case.
if __name__ == "__main__":
    # rojo_method requires abs(c) > 2 * abs(a); with a = 1 - c that means
    # c > 2/3, so c is drawn from [0.7, 1.0). N = 3 is the smallest size
    # the method supports.
    for N in (3, 101):
        # Create random example.
        c = np.random.uniform(0.7, 1.0)
        a = 1 - c
        f = np.random.randn(N) * .1

        # Compute solution with scipy
        coeffs = np.zeros(N)
        coeffs[0] = c
        coeffs[1] = a
        coeffs[-1] = a
        x1 = solve_circulant(coeffs, f)
        x2 = np.linalg.solve(circulant(coeffs), f)

        # Compute solution with Rojo method.
        x = np.full(N, np.nan)
        z = np.full(N, np.nan)
        rojo_method(c, a, f, x, z)

        # Check consistency.
        assert_array_almost_equal(x1, x2)
        assert_array_almost_equal(x, x1)
        assert_array_almost_equal(x, x2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment