Skip to content

Instantly share code, notes, and snippets.

@Joshuaalbert
Created April 22, 2017 14:18
Show Gist options
  • Save Joshuaalbert/c394c44c4ffec6ea48780cf63294126d to your computer and use it in GitHub Desktop.
Save Joshuaalbert/c394c44c4ffec6ea48780cf63294126d to your computer and use it in GitHub Desktop.
Test choSolve against LAPACK's dpotrs Cholesky solver
#from scipy.linalg import cho_solve
from scipy.linalg.lapack import dpotrs
import numpy as np
def choBackSubstitution(L, y, lower=True, modify=False):
    """Solve the triangular system ``L x = y`` by substitution.

    Despite the name, both directions are handled:

    * ``lower=True``: ``L`` is taken as lower triangular and solved by
      forward substitution.
    * ``lower=False``: ``L`` is taken as upper triangular and solved by
      back substitution.

    Only the selected triangle of ``L`` (diagonal included) is read. The
    other triangle is ignored.

    Parameters
    ----------
    L : numpy.ndarray, shape (N, N)
        Triangular coefficient matrix; its diagonal must be non-zero.
    y : array_like, shape (N,) or (N, K)
        Right-hand side. A 2-D ``y`` holds K right-hand sides as columns,
        which are all solved together.
    lower : bool, default True
        True if ``L`` is lower triangular, False if it is upper triangular.
    modify : bool, default False
        If True, ``y`` itself is overwritten with the solution and returned.
        In that case it should be a floating-point ndarray. If False, ``y``
        is left untouched and a new array is returned; non-floating input
        (e.g. integers) is promoted to float64.

    Returns
    -------
    numpy.ndarray
        The solution ``x``, with the same shape as ``y``.

    Raises
    ------
    ValueError
        If ``L`` is not a square 2-D array, or if ``y``'s leading dimension
        does not equal ``L.shape[0]``.
    """
    if L.ndim != 2 or L.shape[0] != L.shape[1]:
        raise ValueError(
            "L must be a square 2-D array, got shape {}".format(L.shape))
    n = L.shape[0]

    if modify:
        x = y
    else:
        # np.array copies by default. Promote non-float input so that the
        # in-place division below cannot silently truncate to integers.
        x = np.array(y)
        if not np.issubdtype(x.dtype, np.inexact):
            x = x.astype(np.float64)

    if np.shape(x)[:1] != (n,):
        raise ValueError(
            "y must have leading dimension {}, got shape {}".format(
                n, np.shape(x)))

    # Reshape each column of L to (m, 1, ...) so it broadcasts against every
    # right-hand side at once. For a 1-D y this is just a plain (m,) view.
    trailing = (1,) * (np.ndim(x) - 1)

    if lower:
        # Forward substitution: finish x[i], then eliminate it from the rows
        # below using column i of the lower triangle.
        for i in range(n):
            x[i] /= L[i, i]
            x[i + 1:] -= L[i + 1:, i].reshape((-1,) + trailing) * x[i]
    else:
        # Back substitution: finish x[i] from the bottom up, then eliminate
        # it from the rows above using column i of the upper triangle.
        for i in reversed(range(n)):
            x[i] /= L[i, i]
            x[:i] -= L[:i, i].reshape((-1,) + trailing) * x[i]
    return x
def choSolve(L, b, modify=False):
    """Solve ``A x = b`` given the lower-triangular factor ``L`` of ``A = L L^T``.

    The solve runs in two stages: first ``L z = b`` by forward substitution,
    then ``L^T x = z`` by back substitution. It uses the plain transpose
    ``L.T``, not the conjugate transpose, so it assumes a real ``L``.

    Parameters
    ----------
    L : numpy.ndarray, shape (N, N)
        Lower-triangular factor, e.g. from ``numpy.linalg.cholesky``.
    b : array_like
        Right-hand side, as accepted by ``choBackSubstitution``.
    modify : bool, default False
        If True, ``b`` is overwritten with the solution and returned.
        Otherwise ``b`` is left untouched.

    Returns
    -------
    numpy.ndarray
        The solution ``x``.
    """
    # Stage 1: this copies b unless the caller allowed modification.
    z = choBackSubstitution(L, b, True, modify)
    # z is now either a private copy or b itself (which the caller allowed
    # us to overwrite). Stage 2 can therefore always work in place, which
    # avoids a second, redundant copy of the vector.
    return choBackSubstitution(L.T, z, False, True)
# Benchmark choSolve against LAPACK's dpotrs, which solves A x = b from a
# Cholesky factor of A (scipy.linalg.cho_solve((L, True), b) wraps the same
# routine). The original session used IPython %-magics. These are replaced
# by the stdlib timeit module so the file also runs as a plain script.
import timeit
from time import perf_counter  # time.clock was removed in Python 3.8

import matplotlib.pyplot as plt  # matplotlib discourages the pylab interface

N = 5000
y = np.random.uniform(size=N)
a = np.random.uniform(size=[N, N])
a = a.T.dot(a)  # Gram matrix: symmetric, so it has a Cholesky factor
L = np.linalg.cholesky(a)  # lower triangular, a = L @ L.T

# Line-by-line profile. This needs the third-party line_profiler package,
# so it is only available in IPython:
#   %load_ext line_profiler
#   %lprun -f choBackSubstitution choSolve(L, y, False)

# Best-of-7 mean time per call over 10 calls (equivalent to %timeit -n 10).
n_loops = 10
t_mine = min(timeit.repeat(lambda: choSolve(L, y, False),
                           number=n_loops, repeat=7)) / n_loops
print("choSolve: {:.3g} s per loop".format(t_mine))
t_lapack = min(timeit.repeat(lambda: dpotrs(L, y, lower=1, overwrite_b=0),
                             number=n_loops, repeat=7)) / n_loops
print("dpotrs:   {:.3g} s per loop".format(t_lapack))

# dpotrs returns (solution, info); a negative info flags an illegal argument.
x1, info = dpotrs(L, y, lower=1, overwrite_b=0)
if info != 0:
    raise RuntimeError("dpotrs failed with info={}".format(info))
x2 = choSolve(L, y, False)
print("same:", np.allclose(x1, x2))

# Scaling study: a single wall-clock timing of each solver per problem size.
times1 = []
times2 = []
Ns = 10**np.linspace(1, 4, 20)
for N in Ns:
    N = int(N)
    y = np.random.uniform(size=N)
    a = np.random.uniform(size=[N, N])
    a = a.T.dot(a)
    L = np.linalg.cholesky(a)

    t1 = perf_counter()
    x1, info = dpotrs(L, y, lower=1, overwrite_b=0)
    times1.append(perf_counter() - t1)

    t1 = perf_counter()
    x2 = choSolve(L, y, False)
    times2.append(perf_counter() - t1)

plt.plot(Ns, times1, label='LAPACK dpotrs')
plt.plot(Ns, times2, label='my choSolve')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('N')
plt.ylabel('time [s]')
plt.legend()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment