Created
January 30, 2021 00:33
Implementations of Ordinary Least Squares (OLS) in Python
import numpy as np


def ols_numpy(X, y, fit_intercept=True):
    """
    Fits an OLS regression model by solving the normal equations (X^T X) b = X^T y.
    The intercept term is included by prepending a column of ones to the design matrix.
    Params:
        X - NumPy matrix, size (N, p), of numerical predictors
        y - NumPy array, length N, of numerical response
        fit_intercept - Boolean indicating whether to include an intercept term
    Returns:
        NumPy array of fitted model coefficients: length p + 1 (intercept first)
        if fit_intercept is True, otherwise length p
    Note: Computing the matrix inverse explicitly does not scale well, so the NumPy function
        solve, which employs the LAPACK _gesv routine, is used to solve the normal equations.
        solve requires its coefficient matrix A = X^T X to be square and full-rank, which holds
        when the columns of X are linearly independent.
        If the columns of X are linearly dependent, the function lstsq should be used on X and y
        instead; it utilizes the xGELSD routine and thus works from the singular value decomposition.
    """
    # Number of observations, needed to build the intercept column.
    n_samples = np.shape(X)[0]
    if fit_intercept:
        X = np.hstack((np.ones((n_samples, 1)), X))
    return np.linalg.solve(np.dot(X.T, X), np.dot(X.T, y))
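
The docstring note recommends lstsq when the design matrix is not full-rank. A minimal sketch of that fallback follows; the name ols_lstsq, the returned rank, and the toy data are assumptions made for illustration and are not part of the original gist.

```python
import numpy as np


def ols_lstsq(X, y, fit_intercept=True):
    """Hypothetical companion to ols_numpy for rank-deficient design matrices."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    if fit_intercept:
        # Same design matrix augmentation as in ols_numpy.
        X = np.hstack((np.ones((X.shape[0], 1)), X))
    # lstsq returns (coefficients, residuals, rank, singular values);
    # for a rank-deficient X it returns the minimum-norm least-squares solution.
    coef, _, rank, _ = np.linalg.lstsq(X, y, rcond=None)
    return coef, rank


# Toy data (assumed): the second predictor is exactly twice the first,
# so X^T X is singular and the normal equations have no unique solution.
rng = np.random.default_rng(0)
x1 = rng.normal(size=50)
X_dup = np.column_stack((x1, 2.0 * x1))
y = 1.0 + 3.0 * x1

coef, rank = ols_lstsq(X_dup, y)
fitted = np.hstack((np.ones((50, 1)), X_dup)) @ coef
```

Here the augmented design matrix has three columns but rank 2, so the individual coefficients are not identifiable. lstsq still returns coefficients whose fitted values reproduce y, choosing the solution with the smallest Euclidean norm.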