linear_regressor.py
def train_model(
        learning_rate,
        steps,
        batch_size,
        training_examples,
        training_targets,
        validation_examples,
        validation_targets):
    """Trains a linear regression model of multiple features.

    Training is split into 10 periods. After each period the model is
    evaluated on both the training and the validation data, the training
    RMSE is printed, and once training finishes the training and validation
    RMSE curves are drawn on the current matplotlib figure.

    Args:
      learning_rate: A `float`, the learning rate.
      steps: A non-zero `int`, the total number of training steps. A training
        step consists of a forward and backward pass using a single batch.
        Each period runs `steps // 10` steps, so a value that is not a
        multiple of 10 is rounded down.
      batch_size: A non-zero `int`, the batch size.
      training_examples: A `DataFrame` containing one or more columns from
        `california_housing_dataframe` to use as input features for training.
      training_targets: A `DataFrame` containing exactly one column,
        `"median_house_value"`, from `california_housing_dataframe` to use
        as target for training.
      validation_examples: A `DataFrame` containing one or more columns from
        `california_housing_dataframe` to use as input features for
        validation.
      validation_targets: A `DataFrame` containing exactly one column,
        `"median_house_value"`, from `california_housing_dataframe` to use
        as target for validation.

    Returns:
      A `LinearRegressor` object trained on the training data.
    """
    periods = 10
    # Floor division keeps the per-period step count an integer on both
    # Python 2 and Python 3 (plain `/` yields a float on Python 3).
    steps_per_period = steps // periods

    # Create a linear regressor object.
    my_optimizer = tf.train.GradientDescentOptimizer(
        learning_rate=learning_rate)
    # Clip gradients to a maximum norm of 5.0 so that an occasional very
    # large gradient cannot make training diverge.
    my_optimizer = tf.contrib.estimator.clip_gradients_by_norm(
        my_optimizer, 5.0)
    linear_regressor = tf.estimator.LinearRegressor(
        feature_columns=construct_feature_columns(training_examples),
        optimizer=my_optimizer
    )

    # Create input functions. The training input uses my_input_fn's default
    # epoch/shuffle settings; the prediction inputs make exactly one
    # unshuffled pass so predictions line up row-for-row with the targets.
    training_input_fn = lambda: my_input_fn(
        training_examples,
        training_targets["median_house_value"],
        batch_size=batch_size)
    predict_training_input_fn = lambda: my_input_fn(
        training_examples,
        training_targets["median_house_value"],
        num_epochs=1,
        shuffle=False)
    predict_validation_input_fn = lambda: my_input_fn(
        validation_examples,
        validation_targets["median_house_value"],
        num_epochs=1,
        shuffle=False)

    # Train the model, but do so inside a loop so that we can periodically
    # assess loss metrics.
    print("Training model...")
    print("RMSE (on training data):")
    training_rmse = []
    validation_rmse = []
    for period in range(periods):
        # Train the model, starting from the prior state.
        linear_regressor.train(
            input_fn=training_input_fn,
            steps=steps_per_period,
        )

        # Take a break and compute predictions.
        training_predictions = linear_regressor.predict(
            input_fn=predict_training_input_fn)
        training_predictions = np.array(
            [item['predictions'][0] for item in training_predictions])
        validation_predictions = linear_regressor.predict(
            input_fn=predict_validation_input_fn)
        validation_predictions = np.array(
            [item['predictions'][0] for item in validation_predictions])

        # Compute training and validation loss.
        training_root_mean_squared_error = math.sqrt(
            metrics.mean_squared_error(training_predictions, training_targets))
        validation_root_mean_squared_error = math.sqrt(
            metrics.mean_squared_error(
                validation_predictions, validation_targets))

        # Print the current training loss for this period.
        print(" period %02d : %0.2f" % (
            period, training_root_mean_squared_error))

        # Add the loss metrics from this period to our list.
        training_rmse.append(training_root_mean_squared_error)
        validation_rmse.append(validation_root_mean_squared_error)
    print("Model training finished.")

    # Output a graph of loss metrics over periods.
    plt.ylabel("RMSE")
    plt.xlabel("Periods")
    plt.title("Root Mean Squared Error vs. Periods")
    plt.tight_layout()
    plt.plot(training_rmse, label="training")
    plt.plot(validation_rmse, label="validation")
    plt.legend()

    return linear_regressor
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment