@aldro61
Last active August 9, 2021 15:20
A linear least-squares solver for Python. By solving the normal equations with BLAS `dgemm`, this function outperforms numpy.linalg.lstsq in terms of computation time and memory.
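For reference, `linear_least_squares` builds $a^\top a$ and $a^\top b$ with `dgemm` and passes them to `np.linalg.solve`, i.e. it uses the normal-equations method. Setting the gradient of the squared residual to zero gives that $N \times N$ system:

$$
\hat{x} = \arg\min_{x} \|b - a x\|_2^2,
\qquad
\nabla_x \|b - a x\|_2^2 = -2\,a^\top (b - a x) = 0
\;\Longrightarrow\;
a^\top a\,\hat{x} = a^\top b .
$$

The system has a unique solution only when $a$ has full column rank. Since the singular values of $a^\top a$ are the squares of those of $a$, $\kappa_2(a^\top a) = \kappa_2(a)^2$, so for ill-conditioned $a$ this route loses more accuracy than the SVD-based `numpy.linalg.lstsq`; that is the usual price paid for the speed and memory savings.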
# Copyright (c) 2013 Alexandre Drouin. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# If you happen to meet one of the copyright holders in a bar you are obligated
# to buy them one pint of beer.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from warnings import warn

import numpy as np
# scipy.linalg.fblas is not available in recent SciPy releases; dgemm lives in scipy.linalg.blas.
from scipy.linalg.blas import dgemm


def linear_least_squares(a, b, residuals=False):
    """
    Return the least-squares solution to a linear matrix equation.

    Solves the equation `a x = b` by computing a vector `x` that
    minimizes the squared Euclidean 2-norm `|| b - a x ||^2`. The solution
    is obtained from the normal equations `(a^T a) x = a^T b`, so `a` must
    have full column rank (linearly independent columns); otherwise `a^T a`
    is singular and `np.linalg.solve` raises `LinAlgError`. If `a` is
    square and of full rank, then `x` (but for round-off error) is the
    "exact" solution of the equation.

    Parameters
    ----------
    a : (M, N) array_like
        "Coefficient" matrix. Ideally a C-contiguous float64 ndarray, so
        that `a.T` can be passed to BLAS without a copy.
    b : (M,) array_like
        Ordinate or "dependent variable" values.
    residuals : bool, optional
        If True, also return the norm of the residual vector.

    Returns
    -------
    x : (N,) ndarray
        Least-squares solution.
    residual : float, optional
        Euclidean 2-norm `|| a x - b ||` of the residual vector. Only
        returned if `residuals` is True.
    """
    if not isinstance(a, np.ndarray) or not a.flags['C_CONTIGUOUS']:
        warn('Matrix a is not a C-contiguous numpy array. The solver will create a copy, '
             'which will result in increased memory usage.')
        a = np.asarray(a, order='C')
    b = np.asarray(b, dtype=np.float64)

    # a is C-contiguous, so a.T is a Fortran-ordered view that BLAS can read
    # directly. With trans_b=True, dgemm computes a.T @ (a.T).T = a.T @ a.
    ata = dgemm(alpha=1.0, a=a.T, b=a.T, trans_b=True)
    # a.T @ b, with b passed as an (M, 1) column; flatten back to shape (N,).
    atb = dgemm(alpha=1.0, a=a.T, b=b.reshape(-1, 1)).flatten()
    x = np.linalg.solve(ata, atb)

    if residuals:
        return x, np.linalg.norm(np.dot(a, x) - b)
    return x


if __name__ == "__main__":
    # Fit a line y = m * x + c to four points.
    x = np.array([0, 1, 2, 3])
    y = np.array([-1, 0.2, 0.9, 2.1])
    A = np.vstack([x, np.ones(len(x))]).T
    A = np.asarray(A, order='C')
    m, c = linear_least_squares(A, y)
    print(m, c)

    import matplotlib.pyplot as plt
    plt.plot(x, y, 'o', label='Original data', markersize=10)
    plt.plot(x, m * x + c, 'r', label='Fitted line')
    plt.legend()
    plt.show()
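To check the normal-equations approach without depending on SciPy's BLAS wrappers, here is a minimal sketch that reimplements it with plain NumPy matrix products and compares it with `numpy.linalg.lstsq` on random, full-column-rank data. The function name `normal_equations_lstsq` is hypothetical and not part of the gist. The sketch also shows the failure mode: with linearly dependent columns, `a^T a` is singular and `np.linalg.solve` raises `LinAlgError`.

```python
import numpy as np


def normal_equations_lstsq(a, b):
    """Hypothetical NumPy-only sketch of the normal-equations solver.

    Forms a^T a and a^T b, then solves the small N x N system.
    Assumes `a` has full column rank; otherwise a^T a is singular.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    ata = a.T @ a  # (N, N) Gram matrix
    atb = a.T @ b  # (N,) right-hand side
    return np.linalg.solve(ata, atb)


# Over-determined, full-column-rank test problem (random data).
rng = np.random.default_rng(0)
a = rng.standard_normal((200, 5))
b = rng.standard_normal(200)

x_ne = normal_equations_lstsq(a, b)
x_ref, *_ = np.linalg.lstsq(a, b, rcond=None)  # SVD-based reference
print(np.allclose(x_ne, x_ref))  # True

# Rank-deficient case: two identical columns make a^T a exactly singular.
try:
    normal_equations_lstsq(np.ones((3, 2)), np.ones(3))
except np.linalg.LinAlgError as err:
    print("LinAlgError:", err)
```

`numpy.linalg.lstsq` handles rank-deficient `a` by way of the SVD; the normal equations trade that robustness for a single small N × N solve.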

thoppe commented Jan 13, 2014

According to the docs, your import statement looks incorrect. See:

http://docs.scipy.org/doc/scipy-dev/reference/generated/scipy.linalg.blas.dgemm.html

I had to change the dgemm import to this to get it to work:

from scipy.linalg.blas import dgemm
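A quick sanity check of the corrected import and of the transpose trick the solver relies on: `a.T` of a C-contiguous array is a Fortran-ordered view, and with `trans_b=True`, `dgemm` returns `a.T @ a`. This sketch assumes a SciPy version where `scipy.linalg.blas.dgemm` is available; the matrix is random test data, not taken from the gist.

```python
import numpy as np
from scipy.linalg.blas import dgemm

# Random C-contiguous float64 test matrix.
a = np.asarray(np.random.default_rng(1).standard_normal((6, 3)), order="C")

# a.T is a Fortran-ordered view of `a`; with trans_b=True, dgemm computes
# op(A) @ op(B) = a.T @ (a.T).T = a.T @ a.
ata = dgemm(alpha=1.0, a=a.T, b=a.T, trans_b=True)

print(ata.shape)                  # (3, 3)
print(np.allclose(ata, a.T @ a))  # True
print(a.T.flags["F_CONTIGUOUS"])  # True
```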
