Skip to content

Instantly share code, notes, and snippets.

@stichbury
Created April 26, 2021 10:30
Show Gist options
  • Save stichbury/e7bc663368b92d17fd72f82115bde16c to your computer and use it in GitHub Desktop.
Save stichbury/e7bc663368b92d17fd72f82115bde16c to your computer and use it in GitHub Desktop.
def split_data(data: pd.DataFrame, parameters: Dict) -> Tuple:
    """Partition the dataset into training and test features and targets.

    Args:
        data: Table holding the feature columns and the ``price`` column.
        parameters: Mapping that supplies the ``features`` column selection
            plus the ``test_size`` and ``random_state`` values forwarded
            to ``train_test_split``.

    Returns:
        A 4-tuple ``(X_train, X_test, y_train, y_test)``.
    """
    feature_frame = data[parameters["features"]]
    price_column = data["price"]

    # Gather the split configuration after selecting the columns so that
    # lookups happen in the same order as before.
    split_options = {
        "test_size": parameters["test_size"],
        "random_state": parameters["random_state"],
    }
    train_features, test_features, train_prices, test_prices = train_test_split(
        feature_frame, price_column, **split_options
    )
    return train_features, test_features, train_prices, test_prices
def train_model(X_train: pd.DataFrame, y_train: pd.Series) -> LinearRegression:
    """Fit a linear regression model on the training data.

    Args:
        X_train: Training rows of the independent features.
        y_train: Training values of the price target.

    Returns:
        The ``LinearRegression`` instance after ``fit`` has been called on it.
    """
    fitted_model = LinearRegression()
    # The instance itself is returned, not the result of ``fit``.
    fitted_model.fit(X_train, y_train)
    return fitted_model
def evaluate_model(
    regressor: LinearRegression, X_test: pd.DataFrame, y_test: pd.Series
):
    """Score the model on the test data and log the result.

    Predicts on ``X_test``, computes ``r2_score`` against ``y_test`` and
    writes the value at INFO level to the logger named after this module.

    Args:
        regressor: Trained model exposing ``predict``.
        X_test: Test rows of the independent features.
        y_test: Test values of the price target.
    """
    predictions = regressor.predict(X_test)
    r_squared = r2_score(y_test, predictions)
    # Lazy %-style arguments: the message is only formatted if it is emitted.
    logging.getLogger(__name__).info(
        "Model has a coefficient R^2 of %.3f on test data.", r_squared
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment