Skip to content

Instantly share code, notes, and snippets.

@thierryherrmann
Created August 2, 2020 23:18
Show Gist options
  • Save thierryherrmann/6dca3ace95450be3391c5498b63710a6 to your computer and use it in GitHub Desktop.
Save thierryherrmann/6dca3ace95450be3391c5498b63710a6 to your computer and use it in GitHub Desktop.
def train_module(module, train_dataset, valid_dataset, epochs=3, eval_every=100):
    """Train ``module`` and periodically report its validation MSE.

    Args:
        module: object exposing ``my_train(X, y)`` and ``__call__(X)``.
            ``my_train(X, y)`` runs one training step and returns a loss
            value with a ``.numpy()`` method. ``__call__(X)`` returns
            predictions for ``X``.
        train_dataset: iterable of ``(X, y)`` training batches. It is
            iterated once per epoch, so it must be re-iterable (e.g. a
            list or a ``tf.data.Dataset``). A one-shot generator would be
            exhausted after the first epoch.
        valid_dataset: iterable of ``(X_val, y_val)`` validation batches.
            It is iterated in full at every validation point.
        epochs: number of passes over ``train_dataset`` (default 3).
        eval_every: run validation every ``eval_every`` global training
            steps (default 100). Pass 0 or None to skip validation.

    Returns:
        list of per-step loss values (as returned by ``loss.numpy()``),
        in training order.
    """
    loss_hist = []
    # Global step counter: 1-based and NOT reset between epochs, so
    # validation happens every ``eval_every`` batches across the whole run.
    step = 1
    for _epoch in range(epochs):
        for X, y in train_dataset:
            loss = module.my_train(X, y)
            loss_hist.append(loss.numpy())
            if eval_every and step % eval_every == 0:
                mse = _validation_mse(module, valid_dataset)
                print(f'Mean squared error: step {step}: {mse}')
            step += 1
    return loss_hist


def _validation_mse(module, valid_dataset):
    """Return the MSE of ``module`` over one full pass of ``valid_dataset``.

    A fresh metric is built on every call so the result reflects only the
    module's current weights. A Keras metric object accumulates state
    across ``update_state`` calls until it is explicitly reset, so reusing
    one instance would report a running average over all past passes.
    """
    metric = keras.metrics.MeanSquaredError()
    for X_val, y_val in valid_dataset:
        metric.update_state(y_val, module(X_val))
    # Convert the scalar tensor to a plain float for readable printing.
    return float(metric.result())
def plot_loss(loss_hist):
    """Draw ``loss_hist`` as a line plot on a new 8x4-inch figure.

    Args:
        loss_hist: sequence of loss values (e.g. the list returned by
            ``train_module``). Each value is plotted against its index.

    The figure gets the title 'loss' and a grid. It is not shown or
    closed here; displaying it is left to the caller (e.g. notebook inline
    rendering or an explicit ``plt.show()``).
    """
    plt.figure(figsize=(8, 4))
    axes = plt.gca()
    axes.plot(loss_hist)
    axes.set_title('loss', fontsize=15)
    axes.grid()
# Script step: train the module, then plot the per-step loss curve.
# NOTE(review): `module`, `train_dataset` and `valid_dataset` are expected
# to be defined earlier in the file/notebook (not visible here).
loss_hist = train_module(
    module=module,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
)
plot_loss(loss_hist=loss_hist)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment