Skip to content

Instantly share code, notes, and snippets.

@staticor
Created September 29, 2019 00:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save staticor/383ebf10bceabbe9b8e46da7e4927dec to your computer and use it in GitHub Desktop.
Save staticor/383ebf10bceabbe9b8e46da7e4927dec to your computer and use it in GitHub Desktop.
Least squares method in linear regression.
# Linear Least Squares
def lstsq(a, b, cond=None, overwrite_a=False, overwrite_b=False,
          check_finite=True, lapack_driver=None):
    """
    Compute least-squares solution to equation Ax = b.

    Compute a vector x such that the 2-norm ``|b - A x|`` is minimized.

    Parameters
    ----------
    a : (M, N) array_like
        Left hand side array
    b : (M,) or (M, K) array_like
        Right hand side array
    cond : float, optional
        Cutoff for 'small' singular values; used to determine effective
        rank of a. Singular values smaller than
        ``cond * largest_singular_value`` are considered zero.
        If None, the machine epsilon of the dtype of the selected LAPACK
        routine is used.
    overwrite_a : bool, optional
        Discard data in `a` (may enhance performance). Default is False.
    overwrite_b : bool, optional
        Discard data in `b` (may enhance performance). Default is False.
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
    lapack_driver : str, optional
        Which LAPACK driver is used to solve the least-squares problem.
        Options are ``'gelsd'``, ``'gelsy'``, ``'gelss'``. Default
        (``'gelsd'``) is a good choice.  However, ``'gelsy'`` can be slightly
        faster on many problems.  ``'gelss'`` was used historically.  It is
        generally slow but uses less memory.
        If None, ``lstsq.default_lapack_driver`` is used.

        .. versionadded:: 0.17.0

    Returns
    -------
    x : (N,) or (N, K) ndarray
        Least-squares solution.  Return shape matches shape of `b`.
    residues : (K,) ndarray or float
        Square of the 2-norm for each column in ``b - a x``, if ``M > N`` and
        ``rank(A) == n`` (returns a scalar if b is 1-D). Otherwise a
        (0,)-shaped array is returned.
    rank : int
        Effective rank of `a`.
    s : (min(M, N),) ndarray or None
        Singular values of `a`. The condition number of a is
        ``abs(s[0] / s[-1])``.

    Raises
    ------
    LinAlgError
        If computation does not converge.
    ValueError
        When parameters are not compatible.

    See Also
    --------
    scipy.optimize.nnls : linear least squares with non-negativity constraint

    Notes
    -----
    When ``'gelsy'`` is used as a driver, `residues` is set to a (0,)-shaped
    array and `s` is always ``None``.

    When `a` has zero rows or zero columns, LAPACK is not called: `x` is all
    zeros, `rank` is 0 and `s` is a (0,)-shaped array.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lstsq
    >>> import matplotlib.pyplot as plt

    Suppose we have the following data:

    >>> x = np.array([1, 2.5, 3.5, 4, 5, 7, 8.5])
    >>> y = np.array([0.3, 1.1, 1.5, 2.0, 3.2, 6.6, 8.6])

    We want to fit a quadratic polynomial of the form ``y = a + b*x**2``
    to this data.  We first form the "design matrix" M, with a constant
    column of 1s and a column containing ``x**2``:

    >>> M = x[:, np.newaxis]**[0, 2]
    >>> M
    array([[  1.  ,   1.  ],
           [  1.  ,   6.25],
           [  1.  ,  12.25],
           [  1.  ,  16.  ],
           [  1.  ,  25.  ],
           [  1.  ,  49.  ],
           [  1.  ,  72.25]])

    We want to find the least-squares solution to ``M.dot(p) = y``,
    where ``p`` is a vector with length 2 that holds the parameters
    ``a`` and ``b``.

    >>> p, res, rnk, s = lstsq(M, y)
    >>> p
    array([ 0.20925829,  0.12013861])

    Plot the data and the fitted curve.

    >>> plt.plot(x, y, 'o', label='data')
    >>> xx = np.linspace(0, 9, 101)
    >>> yy = p[0] + p[1]*xx**2
    >>> plt.plot(xx, yy, label='least squares fit, $y = a + bx^2$')
    >>> plt.xlabel('x')
    >>> plt.ylabel('y')
    >>> plt.legend(framealpha=1, shadow=True)
    >>> plt.grid(alpha=0.25)
    >>> plt.show()

    """
    a1 = _asarray_validated(a, check_finite=check_finite)
    b1 = _asarray_validated(b, check_finite=check_finite)
    if a1.ndim != 2:
        raise ValueError('Input array a should be 2-D')
    m, n = a1.shape
    # A 1-D right-hand side is treated as a single column.
    nrhs = b1.shape[1] if b1.ndim == 2 else 1
    if m != b1.shape[0]:
        raise ValueError('Shape mismatch: a and b should have the same number'
                         ' of rows ({} != {}).'.format(m, b1.shape[0]))

    if m == 0 or n == 0:  # Zero-sized problem, confuses LAPACK
        x = np.zeros((n,) + b1.shape[1:], dtype=np.common_type(a1, b1))
        if n == 0:
            # No unknowns: the residual is b itself.
            residues = np.linalg.norm(b1, axis=0)**2
        else:
            residues = np.empty((0,))
        return x, residues, 0, np.empty((0,))

    driver = lapack_driver
    if driver is None:
        driver = lstsq.default_lapack_driver
    if driver not in ('gelsd', 'gelsy', 'gelss'):
        raise ValueError('LAPACK driver "%s" is not found' % driver)

    lapack_func, lapack_lwork = get_lapack_funcs((driver,
                                                  '%s_lwork' % driver),
                                                 (a1, b1))
    real_data = lapack_func.dtype.kind == 'f'

    if m < n:
        # need to extend b matrix as it will be filled with
        # a larger solution matrix
        if b1.ndim == 2:
            b2 = np.zeros((n, nrhs), dtype=lapack_func.dtype)
            b2[:m, :] = b1
        else:
            b2 = np.zeros(n, dtype=lapack_func.dtype)
            b2[:m] = b1
        b1 = b2

    # Working copies made during validation/padding may be clobbered freely.
    overwrite_a = overwrite_a or _datacopied(a1, a)
    overwrite_b = overwrite_b or _datacopied(b1, b)

    if cond is None:
        cond = np.finfo(lapack_func.dtype).eps

    if driver in ('gelss', 'gelsd'):
        if driver == 'gelss':
            lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
            v, x, s, rank, work, info = lapack_func(a1, b1, cond, lwork,
                                                    overwrite_a=overwrite_a,
                                                    overwrite_b=overwrite_b)
        elif driver == 'gelsd':
            # NOTE(review): the two trailing False arguments are presumably
            # the routine's overwrite flags, i.e. overwrite_a/overwrite_b are
            # not forwarded on this path -- confirm against the wrapper.
            if real_data:
                lwork, iwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
                x, s, rank, info = lapack_func(a1, b1, lwork,
                                               iwork, cond, False, False)
            else:  # complex data
                lwork, rwork, iwork = _compute_lwork(lapack_lwork, m, n,
                                                     nrhs, cond)
                x, s, rank, info = lapack_func(a1, b1, lwork, rwork, iwork,
                                               cond, False, False)
        if info > 0:
            raise LinAlgError("SVD did not converge in Linear Least Squares")
        if info < 0:
            # Report the driver actually used; `lapack_driver` may be None
            # when the default was selected.
            raise ValueError('illegal value in %d-th argument of internal %s'
                             % (-info, driver))
        resids = np.asarray([], dtype=x.dtype)
        if m > n:
            x1 = x[:n]
            if rank == n:
                # Rows n and beyond of the returned array are summed
                # (squared) to give the residual of each right-hand side.
                resids = np.sum(np.abs(x[n:])**2, axis=0)
            x = x1
        return x, resids, rank, s

    elif driver == 'gelsy':
        lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
        jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
        # NOTE(review): as with 'gelsd', the trailing False, False are
        # presumably the overwrite flags -- confirm against the wrapper.
        v, x, j, rank, info = lapack_func(a1, b1, jptv, cond,
                                          lwork, False, False)
        if info < 0:
            raise ValueError("illegal value in %d-th argument of internal "
                             "gelsy" % -info)
        if m > n:
            # Only the first n rows hold the solution.
            x = x[:n]
        return x, np.array([], x.dtype), rank, None
# Driver that lstsq falls back to when called with lapack_driver=None.
lstsq.default_lapack_driver = 'gelsd'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment