Created
September 29, 2019 00:46
-
-
Save staticor/383ebf10bceabbe9b8e46da7e4927dec to your computer and use it in GitHub Desktop.
Least squares method in linear regression.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Linear Least Squares
def lstsq(a, b, cond=None, overwrite_a=False, overwrite_b=False,
          check_finite=True, lapack_driver=None):
    """
    Compute least-squares solution to equation Ax = b.

    Compute a vector x such that the 2-norm ``|b - A x|`` is minimized.

    Parameters
    ----------
    a : (M, N) array_like
        Left hand side array
    b : (M,) or (M, K) array_like
        Right hand side array
    cond : float, optional
        Cutoff for 'small' singular values; used to determine effective
        rank of a. Singular values smaller than
        ``cond * largest_singular_value`` are considered zero.
        If None, the machine epsilon of the LAPACK routine's dtype is used.
    overwrite_a : bool, optional
        Discard data in `a` (may enhance performance). Default is False.
    overwrite_b : bool, optional
        Discard data in `b` (may enhance performance). Default is False.
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
    lapack_driver : str, optional
        Which LAPACK driver is used to solve the least-squares problem.
        Options are ``'gelsd'``, ``'gelsy'``, ``'gelss'``. Default
        (``'gelsd'``) is a good choice.  However, ``'gelsy'`` can be slightly
        faster on many problems.  ``'gelss'`` was used historically.  It is
        generally slow but uses less memory.

        .. versionadded:: 0.17.0

    Returns
    -------
    x : (N,) or (N, K) ndarray
        Least-squares solution.  Return shape matches shape of `b`.
    residues : (K,) ndarray or float
        Square of the 2-norm for each column in ``b - a x``, if ``M > N`` and
        ``rank(A) == N`` (returns a scalar if b is 1-D). Otherwise a
        (0,)-shaped array is returned.
    rank : int
        Effective rank of `a`.
    s : (min(M, N),) ndarray or None
        Singular values of `a`. The condition number of a is
        ``abs(s[0] / s[-1])``.

    Raises
    ------
    LinAlgError
        If computation does not converge.
    ValueError
        When parameters are not compatible.

    See Also
    --------
    scipy.optimize.nnls : linear least squares with non-negativity constraint

    Notes
    -----
    When ``'gelsy'`` is used as a driver, `residues` is set to a (0,)-shaped
    array and `s` is always ``None``.

    When `a` has zero rows or zero columns LAPACK is not called at all: the
    solution is all zeros, `rank` is 0 and `s` is a (0,)-shaped array.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lstsq
    >>> import matplotlib.pyplot as plt

    Suppose we have the following data:

    >>> x = np.array([1, 2.5, 3.5, 4, 5, 7, 8.5])
    >>> y = np.array([0.3, 1.1, 1.5, 2.0, 3.2, 6.6, 8.6])

    We want to fit a quadratic polynomial of the form ``y = a + b*x**2``
    to this data.  We first form the "design matrix" M, with a constant
    column of 1s and a column containing ``x**2``:

    >>> M = x[:, np.newaxis]**[0, 2]
    >>> M
    array([[ 1.  ,  1.  ],
           [ 1.  ,  6.25],
           [ 1.  , 12.25],
           [ 1.  , 16.  ],
           [ 1.  , 25.  ],
           [ 1.  , 49.  ],
           [ 1.  , 72.25]])

    We want to find the least-squares solution to ``M.dot(p) = y``,
    where ``p`` is a vector with length 2 that holds the parameters
    ``a`` and ``b``.

    >>> p, res, rnk, s = lstsq(M, y)
    >>> p
    array([0.20925829, 0.12013861])

    Plot the data and the fitted curve.

    >>> plt.plot(x, y, 'o', label='data')
    >>> xx = np.linspace(0, 9, 101)
    >>> yy = p[0] + p[1]*xx**2
    >>> plt.plot(xx, yy, label='least squares fit, $y = a + bx^2$')
    >>> plt.xlabel('x')
    >>> plt.ylabel('y')
    >>> plt.legend(framealpha=1, shadow=True)
    >>> plt.grid(alpha=0.25)
    >>> plt.show()

    """
    a1 = _asarray_validated(a, check_finite=check_finite)
    b1 = _asarray_validated(b, check_finite=check_finite)
    if len(a1.shape) != 2:
        raise ValueError('Input array a should be 2-D')
    m, n = a1.shape
    if len(b1.shape) == 2:
        nrhs = b1.shape[1]
    else:
        nrhs = 1
    if m != b1.shape[0]:
        raise ValueError('Shape mismatch: a and b should have the same number'
                         ' of rows ({} != {}).'.format(m, b1.shape[0]))
    if m == 0 or n == 0:  # Zero-sized problem, confuses LAPACK
        x = np.zeros((n,) + b1.shape[1:], dtype=np.common_type(a1, b1))
        if n == 0:
            # No unknowns: the residual is simply b itself.
            residues = np.linalg.norm(b1, axis=0)**2
        else:
            residues = np.empty((0,))
        return x, residues, 0, np.empty((0,))

    # Resolve the driver name; the default lives on the function object so
    # that it can be changed globally without touching call sites.
    driver = lapack_driver
    if driver is None:
        driver = lstsq.default_lapack_driver
    if driver not in ('gelsd', 'gelsy', 'gelss'):
        raise ValueError('LAPACK driver "%s" is not found' % driver)

    lapack_func, lapack_lwork = get_lapack_funcs((driver,
                                                  '%s_lwork' % driver),
                                                 (a1, b1))
    # Real and complex gelsd variants need different workspace arguments.
    real_data = lapack_func.dtype.kind == 'f'

    if m < n:
        # need to extend b matrix as it will be filled with
        # a larger solution matrix
        if len(b1.shape) == 2:
            b2 = np.zeros((n, nrhs), dtype=lapack_func.dtype)
            b2[:m, :] = b1
        else:
            b2 = np.zeros(n, dtype=lapack_func.dtype)
            b2[:m] = b1
        b1 = b2

    # If validation (or the padding above) already produced a copy, that
    # copy may be overwritten freely regardless of what the caller asked.
    overwrite_a = overwrite_a or _datacopied(a1, a)
    overwrite_b = overwrite_b or _datacopied(b1, b)

    if cond is None:
        cond = np.finfo(lapack_func.dtype).eps

    if driver in ('gelss', 'gelsd'):
        if driver == 'gelss':
            lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
            _v, x, s, rank, _work, info = lapack_func(
                a1, b1, cond, lwork,
                overwrite_a=overwrite_a,
                overwrite_b=overwrite_b)

        elif driver == 'gelsd':
            if real_data:
                lwork, iwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
                x, s, rank, info = lapack_func(a1, b1, lwork,
                                               iwork, cond, False, False)
            else:  # complex data
                lwork, rwork, iwork = _compute_lwork(lapack_lwork, m, n,
                                                     nrhs, cond)
                x, s, rank, info = lapack_func(a1, b1, lwork, rwork, iwork,
                                               cond, False, False)
        if info > 0:
            raise LinAlgError("SVD did not converge in Linear Least Squares")
        if info < 0:
            # Report the driver actually used; `lapack_driver` is None when
            # the default was selected.
            raise ValueError('illegal value in %d-th argument of internal %s'
                             % (-info, driver))
        resids = np.asarray([], dtype=x.dtype)
        if m > n:
            x1 = x[:n]
            if rank == n:
                # For a full-rank overdetermined system the rows of the
                # LAPACK output below the solution give the residuals.
                resids = np.sum(np.abs(x[n:])**2, axis=0)
            x = x1
        return x, resids, rank, s

    elif driver == 'gelsy':
        lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
        jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
        _v, x, _j, rank, info = lapack_func(a1, b1, jptv, cond,
                                            lwork, False, False)
        if info < 0:
            raise ValueError("illegal value in %d-th argument of internal "
                             "gelsy" % -info)
        if m > n:
            x = x[:n]
        # gelsy provides neither residues nor singular values.
        return x, np.array([], x.dtype), rank, None


# Driver used when ``lapack_driver`` is not given explicitly.
lstsq.default_lapack_driver = 'gelsd'
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment