-
-
Save iamtimdavis/767d122d365fde9f7d9185ae191e08d1 to your computer and use it in GitHub Desktop.
linear_regressor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _predictions_as_array(estimator, input_fn):
    """Return a numpy array of ``item['predictions'][0]`` for each item from ``estimator.predict``."""
    return np.array(
        [item['predictions'][0] for item in estimator.predict(input_fn=input_fn)])


def _root_mean_squared_error(targets, predictions):
    """Return sqrt of sklearn's mean squared error of ``predictions`` against ``targets``."""
    return math.sqrt(metrics.mean_squared_error(targets, predictions))


def train_model(
        learning_rate,
        steps,
        batch_size,
        training_examples,
        training_targets,
        validation_examples,
        validation_targets,
        periods=10):
    """Trains a linear regression model of multiple features.

    Training is split into ``periods`` rounds. After each round the model's
    RMSE is computed on both the training and the validation data. The
    training RMSE is printed every round. When training finishes, both RMSE
    curves are plotted on the current matplotlib figure.

    Args:
      learning_rate: A `float`, the learning rate.
      steps: A non-zero `int`, the total number of training steps. A training step
        consists of a forward and backward pass using a single batch. Steps are
        spread as evenly as possible over the periods: the first
        ``steps % periods`` periods each get one extra step. Exactly ``steps``
        steps are therefore run in total.
      batch_size: A non-zero `int`, the batch size.
      training_examples: A `DataFrame` containing one or more columns from
        `california_housing_dataframe` to use as input features for training.
      training_targets: A `DataFrame` containing exactly one column from
        `california_housing_dataframe` to use as target for training. It must
        contain a ``"median_house_value"`` column.
      validation_examples: A `DataFrame` containing one or more columns from
        `california_housing_dataframe` to use as input features for validation.
      validation_targets: A `DataFrame` containing exactly one column from
        `california_housing_dataframe` to use as target for validation. It must
        contain a ``"median_house_value"`` column.
      periods: A positive `int`, the number of train/evaluate rounds.
        Defaults to 10, which was previously hard-coded.

    Returns:
      A `LinearRegressor` object trained on the training data.

    Raises:
      ValueError: If ``periods`` is less than 1, or if ``steps`` is less than
        ``periods``. In the second case some round would train for zero steps.
    """
    # Validate before touching TensorFlow. A zero step count would otherwise
    # only fail deep inside Estimator.train.
    if periods < 1:
        raise ValueError("periods must be at least 1, got %r" % (periods,))
    if steps < periods:
        raise ValueError(
            "steps (%r) must be at least periods (%r) so that every period "
            "trains for at least one step" % (steps, periods))
    # Estimator.train needs an int step count, so use integer division. The
    # remainder goes to the earliest periods so that no steps are dropped.
    base_steps_per_period, extra_steps = divmod(steps, periods)

    target_column = "median_house_value"

    # Plain gradient descent, wrapped so that gradients are clipped by norm
    # (threshold 5.0). Clipping guards against exploding gradients; it does
    # nothing for vanishing ones.
    my_optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    my_optimizer = tf.contrib.estimator.clip_gradients_by_norm(my_optimizer, 5.0)
    linear_regressor = tf.estimator.LinearRegressor(
        feature_columns=construct_feature_columns(training_examples),
        optimizer=my_optimizer)

    # Input functions. my_input_fn and construct_feature_columns are defined
    # outside this function. NOTE(review): num_epochs=1 / shuffle=False
    # presumably means one unshuffled pass for prediction; confirm against
    # my_input_fn.
    training_input_fn = lambda: my_input_fn(
        training_examples,
        training_targets[target_column],
        batch_size=batch_size)
    predict_training_input_fn = lambda: my_input_fn(
        training_examples,
        training_targets[target_column],
        num_epochs=1,
        shuffle=False)
    predict_validation_input_fn = lambda: my_input_fn(
        validation_examples,
        validation_targets[target_column],
        num_epochs=1,
        shuffle=False)

    # Train inside a loop so loss metrics can be assessed periodically.
    print("Training model...")
    print("RMSE (on training data):")
    training_rmse = []
    validation_rmse = []
    for period in range(periods):
        # Train the model, starting from the prior state.
        period_steps = base_steps_per_period + (1 if period < extra_steps else 0)
        linear_regressor.train(input_fn=training_input_fn, steps=period_steps)

        # Take a break and compute predictions and losses on both splits.
        training_predictions = _predictions_as_array(
            linear_regressor, predict_training_input_fn)
        validation_predictions = _predictions_as_array(
            linear_regressor, predict_validation_input_fn)
        training_root_mean_squared_error = _root_mean_squared_error(
            training_targets[target_column], training_predictions)
        validation_root_mean_squared_error = _root_mean_squared_error(
            validation_targets[target_column], validation_predictions)

        print(" period %02d : %0.2f" % (period, training_root_mean_squared_error))
        training_rmse.append(training_root_mean_squared_error)
        validation_rmse.append(validation_root_mean_squared_error)
    print("Model training finished.")

    # Output a graph of loss metrics over periods.
    plt.ylabel("RMSE")
    plt.xlabel("Periods")
    plt.title("Root Mean Squared Error vs. Periods")
    plt.plot(training_rmse, label="training")
    plt.plot(validation_rmse, label="validation")
    plt.legend()
    # Lay out after all artists exist so the labels and legend are accounted for.
    plt.tight_layout()
    return linear_regressor
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment