Gist: wyattowalsh/75a3ea1df349c6a3598839d6f042b9e6
Last active: February 7, 2024 15:33
Implementation of Ordinary Least Squares in Python using NumPy
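For reference, here is the closed-form estimator the function computes, written as the normal equations. This statement assumes that X has full column rank, so that the Gram matrix is invertible. The tilde marks the design matrix after the optional column of ones is prepended; without an intercept it is just X.

```latex
\tilde{X} = \begin{bmatrix} \mathbf{1}_N & X \end{bmatrix} \in \mathbb{R}^{N \times (p+1)},
\qquad
\hat{\beta} = \arg\min_{\beta} \lVert y - \tilde{X}\beta \rVert_2^2

(\tilde{X}^\top \tilde{X})\,\hat{\beta} = \tilde{X}^\top y
\quad\Longrightarrow\quad
\hat{\beta} = (\tilde{X}^\top \tilde{X})^{-1} \tilde{X}^\top y
```

The function solves the left-hand linear system with np.linalg.solve. It never forms the inverse explicitly, which is cheaper and more accurate.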
import numpy as np


def ols(X, y, fit_intercept=True):
    """Ordinary Least Squares (OLS) regression with an optional intercept term.

    Fits an OLS regression model using the closed-form OLS estimator
    (the normal equations). If fit_intercept is True, the intercept is
    included by augmenting the design matrix with a leading column of ones.

    Params:
        X - 2-D NumPy array, shape (N, p), of numerical predictors
        y - NumPy array, length N, of numerical response
        fit_intercept - Boolean indicating whether to include an intercept term

    Returns:
        NumPy array of fitted coefficients: length p + 1 (intercept first)
        if fit_intercept is True, otherwise length p
    """
    m = np.shape(X)[0]  # number of observations N
    if fit_intercept:
        # Prepend a column of ones so the first coefficient is the intercept
        X = np.hstack((np.ones((m, 1)), X))
    # Solve (X^T X) beta = X^T y directly instead of inverting X^T X
    return np.linalg.solve(np.dot(X.T, X), np.dot(X.T, y))
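As a cross-check, the same fit can be computed without forming X^T X. The sketch below is not part of the original gist, and the function name ols_lstsq is hypothetical. It keeps the same interface and intercept augmentation but replaces the normal-equations solve with NumPy's SVD-based np.linalg.lstsq. On noise-free synthetic data generated from y = 2 + 3*x1 - x2, it should return coefficients of approximately [2, 3, -1], with the intercept first.

```python
import numpy as np


def ols_lstsq(X, y, fit_intercept=True):
    """Hypothetical variant of `ols` that calls np.linalg.lstsq
    (SVD-based least squares) instead of solving the normal equations."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    if fit_intercept:
        # Same design-matrix augmentation: a leading column of ones
        X = np.hstack((np.ones((X.shape[0], 1)), X))
    # lstsq minimizes ||X b - y||_2 directly; rcond=None uses the
    # machine-precision cutoff for small singular values
    coef, _residuals, _rank, _sv = np.linalg.lstsq(X, y, rcond=None)
    return coef


# Synthetic, noise-free data generated from y = 2 + 3*x1 - 1*x2
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 2))
y = 2 + 3 * X[:, 0] - 1 * X[:, 1]

print(np.round(ols_lstsq(X, y), 6))  # intercept first, then the two slopes
```

Solving the normal equations squares the 2-norm condition number of the design matrix. The lstsq route is therefore usually the safer choice when predictors are nearly collinear. For well-conditioned data, both approaches give the same coefficients up to rounding.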