Skip to content

Instantly share code, notes, and snippets.

@sharmaeklavya2
Created November 2, 2019 14:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sharmaeklavya2/65cd82bc99453a24e3e8a31db67515e5 to your computer and use it in GitHub Desktop.
Save sharmaeklavya2/65cd82bc99453a24e3e8a31db67515e5 to your computer and use it in GitHub Desktop.
Solving linear equations using conjugate gradient descent with exact arithmetic
import sys
import argparse
from sympy import Matrix, ImmutableMatrix, Rational
def to_print_type(a):
    """Return the argument unchanged.

    An identity hook; presumably intended for converting values before they
    are printed, but nothing in this script calls it.
    """
    unchanged = a
    return unchanged
def cga(Q, b, x, n, err_thresh=0, opt=None):
    """Minimize f(x) = x^T Q x / 2 - b^T x with the conjugate gradient method.

    Arithmetic is carried out in whatever element type the sympy matrices
    hold, so with Rational entries every iterate is exact.  Each iterate and
    its error f(x_k) - opt are printed to stderr.

    Args:
        Q: square sympy matrix; CG assumes it is symmetric positive definite.
        b: column vector of the linear term.
        x: column vector, the starting point x_0.
        n: maximum number of CG iterations to perform.
        err_thresh: stop as soon as f(x_k) - opt < err_thresh.
        opt: optimal value of f.  If None it is computed exactly as
            -b^T Q^{-1} b / 2, which requires inverting Q.

    Returns:
        A tuple (x, ehist): the last iterate as a list of its entries, and
        the errors f(x_k) - opt as floats, one per iterate including x_0.
        The tuple is returned even if the iteration budget runs out before
        convergence.
    """
    zero_vec = ImmutableMatrix([0] * x.shape[0])
    ehist = []
    if opt is None:
        print('computing opt', file=sys.stderr)
        # f is minimized at Q^{-1} b, where it equals -b^T Q^{-1} b / 2.
        opt = - (b.T @ (Q**(-1)) @ b)[0, 0] / 2

    y = (x.T @ Q @ x)[0, 0] / 2 - b.dot(x)
    err = (y - opt).evalf()
    print('e0:', err, file=sys.stderr)
    ehist.append(float(err))

    g = Q @ x - b  # gradient of f at x
    if g == zero_vec:
        return (list(x), ehist)
    u = -g  # first search direction is steepest descent

    for i in range(n):
        # u^T Q is needed for both the step length and the conjugation
        # coefficient; compute it once per iteration.
        uQ = u.T @ Q
        uQu = (uQ @ u)[0, 0]

        # Exact line search along u.
        alpha = - g.dot(u) / uQu
        x = x + alpha * u
        print('x{}: {}'.format(i + 1, ['{:.3f}'.format(xi) for xi in x.evalf()]), file=sys.stderr)

        y = (x.T @ Q @ x)[0, 0] / 2 - b.dot(x)
        err = (y - opt).evalf()
        print('e{}: {}'.format(i + 1, err), file=sys.stderr)
        ehist.append(float(err))

        g = Q @ x - b
        if g == zero_vec or y - opt < err_thresh:
            return (list(x), ehist)

        # Pick beta so the new direction is Q-conjugate to the previous one.
        beta = (uQ @ g)[0, 0] / uQu
        u = -g + beta * u

    # The iteration budget is exhausted without meeting a stop condition:
    # return the last iterate instead of falling through to None.
    return (list(x), ehist)
def main():
    """Run CG on the least-squares problem for the n x n Hilbert matrix.

    Builds A[i][j] = 1 / (i + j + 1) and b = (1, ..., 1), then minimizes
    ||A x - b||^2 (dropping the constant b^T b) with cga in exact rational
    arithmetic.  Prints the final iterate and the error history to stdout.
    """
    parser = argparse.ArgumentParser(
        description=__doc__ or (
            'Solve a Hilbert-matrix linear system with conjugate gradient '
            'descent using exact rational arithmetic.'))
    parser.add_argument('n', type=int, help='size of the Hilbert matrix (>= 1)')
    args = parser.parse_args()
    n = args.n
    if n < 1:
        parser.error('n must be a positive integer')

    A = Matrix([[Rational(1, i + j + 1) for j in range(n)] for i in range(n)])
    b = Matrix([1] * n)

    # ||A x - b||^2 = x^T (2 A^T A) x / 2 - (2 A^T b)^T x + b^T b, so the
    # quadratic handed to cga has Q = 2 A^T A and linear term 2 A^T b.
    # (For the symmetric Hilbert matrix 2 A^T b equals (A + A^T) b.)
    Q = 2 * A.T @ A
    d = 2 * A.T @ b
    x_0 = Matrix([0] * n)

    # The Hilbert matrix is invertible, so the residual can reach 0 and the
    # minimum of the quadratic (without the b^T b constant) is -b^T b.
    x, ehist = cga(Q, d, x_0, n + 5, opt=-b.dot(b))
    print('cga output:', x)
    print('errors:', ehist)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment