This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def norm(x, stats=None):
    """Standardize *x* column-wise using per-feature mean and std.

    Args:
        x: Feature table to normalize (the callers in this file pass
            ``train_dataset`` and ``test_dataset``).
        stats: Table with ``'mean'`` and ``'std'`` columns, presumably
            indexed by feature name (e.g. ``describe().transpose()``) --
            confirm against how ``train_stats`` is built. Defaults to the
            module-level ``train_stats`` so that both the training and the
            test split are scaled with the *training* statistics.

    Returns:
        ``(x - stats['mean']) / stats['std']``.
    """
    if stats is None:
        # Resolved at call time (not definition time) so the function can be
        # defined before ``train_stats`` exists.
        stats = train_stats
    return (x - stats['mean']) / stats['std']
# Scale both splits with norm(), i.e. with the statistics of the training set,
# so the model never sees test-set statistics.
normed_train_data, normed_test_data = (
    norm(train_dataset),
    norm(test_dataset),
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def build_model(input_dim=None):
    """Build and compile the fully connected regression network.

    Two 64-unit ReLU hidden layers feed a single linear output unit; the
    model is trained with RMSprop (learning rate 0.001) on mean squared error.

    Args:
        input_dim: Number of input features. Defaults to the number of
            columns of the module-level ``train_dataset``.

    Returns:
        The compiled ``keras.Sequential`` model.
    """
    if input_dim is None:
        input_dim = len(train_dataset.keys())
    model = keras.Sequential([
        layers.Dense(64, activation=tf.nn.relu, input_shape=[input_dim]),
        layers.Dense(64, activation=tf.nn.relu),
        layers.Dense(1),
    ])
    optimizer = tf.keras.optimizers.RMSprop(0.001)
    # Both metrics are needed downstream: plot_history() reads the
    # 'mean_absolute_error' / 'val_mean_absolute_error' history keys, and the
    # test evaluation unpacks ``loss, mae, mse`` from model.evaluate(), so the
    # metric order matters.
    model.compile(loss='mean_squared_error',
                  optimizer=optimizer,
                  metrics=['mean_absolute_error', 'mean_squared_error'])
    return model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ``model`` is otherwise first assigned only for the later early-stopping run,
# so the summary, sample prediction and first training run below would hit a
# NameError. Build the initial (untrained) model here.
model = build_model()
model.summary()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Smoke-test the untrained model on the first ten normalized training rows;
# the bare expression on the last line displays the result in a notebook.
example_batch = normed_train_data[:10]
example_result = model.predict(
    example_batch
)
example_result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Display training progress by printing a single dot for each completed epoch,
# wrapping onto a new line every 100 epochs.
class PrintDot(keras.callbacks.Callback):
    """Compact progress indicator, used with ``verbose=0`` in ``model.fit``."""

    def on_epoch_end(self, epoch, logs=None):
        # ``logs`` defaults to None to match the Keras Callback signature.
        # A newline precedes epoch indices 0, 100, 200, ... so dots form rows
        # of 100.
        if epoch % 100 == 0:
            print('')
        # Flush so each dot shows up immediately; line-buffered stdout would
        # otherwise hold them back until the next newline.
        print('.', end='', flush=True)
# Train for a fixed number of epochs, holding out 20% of the training rows for
# validation; Keras' own logging is silenced in favour of PrintDot.
EPOCHS = 1000
history = model.fit(
    normed_train_data,
    train_labels,
    epochs=EPOCHS,
    validation_split=0.2,
    verbose=0,
    callbacks=[PrintDot()],
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Tabulate the per-epoch metrics recorded by fit(), tagged with the epoch
# number, and show the final rows.
hist = pd.DataFrame(history.history).assign(epoch=history.epoch)
hist.tail()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def plot_history(history):
    """Plot training vs. validation mean absolute error for each epoch.

    Args:
        history: The object returned by ``model.fit``. Its ``history`` dict
            must contain ``'mean_absolute_error'`` and
            ``'val_mean_absolute_error'``, i.e. the model was compiled with
            that metric and fit with validation data.
    """
    hist = pd.DataFrame(history.history)
    hist['epoch'] = history.epoch
    plt.figure()
    plt.xlabel('Epoch')
    plt.ylabel('Mean Abs Error [MPG]')
    plt.plot(hist['epoch'], hist['mean_absolute_error'],
             label='Train Error')
    plt.plot(hist['epoch'], hist['val_mean_absolute_error'],
             label='Val Error')
    # Without a legend the two labelled curves are indistinguishable.
    plt.legend()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Retrain a fresh model, this time stopping once validation loss stalls.
model = build_model()

# ``patience`` is the number of consecutive epochs without a val_loss
# improvement that EarlyStopping tolerates before ending training.
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)

history = model.fit(
    normed_train_data,
    train_labels,
    epochs=EPOCHS,
    validation_split=0.2,
    verbose=0,
    callbacks=[early_stop, PrintDot()],
)
plot_history(history)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Score the trained model on the held-out test split and report its MAE.
loss, mae, mse = model.evaluate(normed_test_data, test_labels, verbose=0)
print(f"Testing set Mean Abs Error: {mae:5.2f} MPG")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Compare predicted against true MPG on the test split; the diagonal line marks
# perfect predictions.
test_predictions = model.predict(normed_test_data).flatten()
# Open a new figure: with a non-inline backend the scatter would otherwise be
# drawn onto the figure plot_history() left current.
plt.figure()
plt.scatter(test_labels, test_predictions)
plt.xlabel('True Values [MPG]')
plt.ylabel('Predictions [MPG]')
plt.axis('equal')
plt.axis('square')
# Anchor both axes at zero while keeping the current upper limits.
plt.xlim([0, plt.xlim()[1]])
plt.ylim([0, plt.ylim()[1]])
# Assigned to ``_`` to suppress the notebook echo of the Line2D list.
_ = plt.plot([-100, 100], [-100, 100])