@wyattowalsh
Last active February 7, 2024 15:33
Implementation of Ordinary Least Squares in Python using NumPy
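```python
import numpy as np


def ols(X, y, fit_intercept=True):
    """Ordinary Least Squares (OLS) regression model, optionally with an intercept term.
    Fits an OLS regression model using the closed-form OLS estimator (normal equations).
    If requested, the intercept term is included via design matrix augmentation.
    Params:
        X - NumPy 2-D array, shape (N, p), of numerical predictors
        y - NumPy array, length N, of numerical response
        fit_intercept - Boolean indicating whether to include an intercept term
    Returns:
        NumPy array of fitted model coefficients: length p + 1 (intercept first)
        if fit_intercept is True, otherwise length p
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    if fit_intercept:
        # Prepend a column of ones so the first coefficient is the intercept
        X = np.hstack((np.ones((X.shape[0], 1)), X))
    # Solve (X^T X) beta = X^T y directly instead of forming an explicit inverse;
    # np.linalg.solve raises LinAlgError if X^T X is singular
    return np.linalg.solve(X.T @ X, X.T @ y)
```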
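The "closed-form OLS estimator" named in the docstring comes from minimizing the residual sum of squares. Setting the gradient to zero gives the normal equations, which is exactly the linear system handed to `np.linalg.solve` (here X is the augmented design matrix when `fit_intercept` is True):

```latex
\hat{\beta} = \arg\min_{\beta} \lVert y - X\beta \rVert_2^2
\quad\Longrightarrow\quad
\nabla_{\beta} \lVert y - X\beta \rVert_2^2 = -2X^{\top}(y - X\beta) = 0
\quad\Longrightarrow\quad
X^{\top}X\,\hat{\beta} = X^{\top}y
```

When X has full column rank, X^T X is invertible and the estimator is (X^T X)^{-1} X^T y; solving the system directly gives the same result without computing the inverse explicitly.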
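As a sketch that is not part of the original gist, here is a hypothetical `ols_lstsq` variant that passes the least-squares problem to `np.linalg.lstsq`, which works on X directly (via an SVD) instead of forming X^T X. Forming X^T X squares the condition number of X, so this route is generally more robust for ill-conditioned predictors, and when X is rank-deficient `lstsq` returns the minimum-norm solution rather than failing. The synthetic data are made up purely for illustration.

```python
import numpy as np


def ols_lstsq(X, y, fit_intercept=True):
    """Hypothetical OLS variant using np.linalg.lstsq instead of the normal equations."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    if fit_intercept:
        # Same design-matrix augmentation as ols(): a leading column of ones
        X = np.column_stack((np.ones(X.shape[0]), X))
    # lstsq minimizes ||y - X beta||_2; rcond=None uses the machine-precision cutoff
    beta, _residuals, _rank, _singular_values = np.linalg.lstsq(X, y, rcond=None)
    return beta


# Usage: recover known coefficients from noise-free synthetic data
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(100, 3))
true_beta = np.array([2.0, -1.0, 0.5, 3.0])  # intercept first, then 3 slopes
y_demo = true_beta[0] + X_demo @ true_beta[1:]
beta_hat = ols_lstsq(X_demo, y_demo)
```

With noise-free data, `beta_hat` should match `true_beta` up to floating-point error, and it should agree with the normal-equations solution that `ols` computes.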