Skip to content

Instantly share code, notes, and snippets.

@mwitiderrick
Last active December 18, 2023 14:02
Show Gist options
  • Save mwitiderrick/73ec6285b6b7ab09d2b0cb00c26d2782 to your computer and use it in GitHub Desktop.
Save mwitiderrick/73ec6285b6b7ab09d2b0cb00c26d2782 to your computer and use it in GitHub Desktop.
import jax
# Reproducibility seed: drives both the top-level PRNG key and the
# train-state initialisation below.
seed = 0
learning_rate = 1e-5
# NOTE(review): the train_model(...) call later in this script passes
# num_epochs=20 explicitly, so that call does not use this value.
num_epochs = 30

# Top-level PRNG key, split once so `rng` holds a fresh key for any later
# randomness. NOTE(review): `rng` is not used in the visible code.
rng = jax.random.split(jax.random.PRNGKey(seed))[0]

# One validation batch; pixel values are divided by 255.0 (presumably uint8
# images mapped to [0, 1] floats - confirm against the loader).
# NOTE(review): test_images / test_labels are not used in the visible code.
test_images, test_labels = next(iter(validation_loader))
test_images = test_images / 255.0

# Build the train state exactly once. `create_train_state`, `momentum` and
# `validation_loader` are assumed to be defined earlier in the notebook.
state = create_train_state(jax.random.PRNGKey(seed), learning_rate, momentum)

# Per-epoch metric history; train_model() appends one value per epoch.
training_loss = []
training_accuracy = []
testing_loss = []
testing_accuracy = []
def train_model(state, train_loader, test_loader, num_epochs=30):
    """Train ``state`` for ``num_epochs`` epochs, validating after each one.

    Each epoch runs ``train_step`` over every batch of ``train_loader``,
    threading the updated state through. It then runs ``eval_step`` over
    every batch of ``test_loader`` using the state produced by that epoch's
    training. The per-epoch mean loss and accuracy are printed and appended
    to the module-level lists ``training_loss``, ``training_accuracy``,
    ``testing_loss`` and ``testing_accuracy``.

    Args:
        state: Training state accepted by ``train_step`` / ``eval_step``
            (presumably a Flax ``TrainState`` - confirm against
            ``create_train_state``).
        train_loader: Iterable of training batches. It is iterated once per
            epoch, so it must be re-iterable (a one-shot iterator would be
            empty after the first epoch).
        test_loader: Iterable of validation batches. It is also iterated once
            per epoch.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        The state after the final epoch.
    """
    for epoch_idx in tqdm(range(num_epochs)):
        # Training pass: train_step returns the updated state plus a metrics
        # mapping with 'loss' and 'accuracy' entries for the batch.
        train_losses, train_accs = [], []
        for batch in train_loader:
            state, metrics = train_step(state, batch)
            train_losses.append(metrics['loss'])
            train_accs.append(metrics['accuracy'])

        # Validation pass with this epoch's freshly trained state.
        eval_results = [eval_step(state, batch) for batch in test_loader]
        val_losses = [res['loss'] for res in eval_results]
        val_accs = [res['accuracy'] for res in eval_results]

        # Epoch metrics are unweighted means of the per-batch values, so a
        # smaller final batch counts as much as a full one.
        mean_train_loss = np.mean(train_losses)
        mean_train_acc = np.mean(train_accs)
        mean_val_loss = np.mean(val_losses)
        mean_val_acc = np.mean(val_accs)

        training_loss.append(mean_train_loss)
        training_accuracy.append(mean_train_acc)
        testing_loss.append(mean_val_loss)
        testing_accuracy.append(mean_val_acc)

        print(
            f"Epoch: {epoch_idx + 1}, loss: {mean_train_loss:.2f}, acc: {mean_train_acc:.2f} val loss: {mean_val_loss:.2f} val acc {mean_val_acc:.2f} "
        )

    return state
# Run training, validating on `validation_loader` after every epoch.
# NOTE(review): num_epochs is hard-coded to 20 here although a module-level
# `num_epochs = 30` is set earlier in the script; confirm which is intended.
trained_model_state = train_model(
    state,
    train_loader,
    test_loader=validation_loader,
    num_epochs=20,
)
# Output recorded from an earlier run. Only epochs 1-8 appear even though
# 20 epochs are requested above, so the log is presumably truncated.
"""
Epoch: 1, loss: 6.41, acc: 0.57 val loss: 2.28 val acc 0.59
Epoch: 2, loss: 1.72, acc: 0.64 val loss: 1.35 val acc 0.66
Epoch: 3, loss: 1.29, acc: 0.68 val loss: 1.43 val acc 0.63
Epoch: 4, loss: 0.98, acc: 0.72 val loss: 1.14 val acc 0.67
Epoch: 5, loss: 0.83, acc: 0.74 val loss: 1.73 val acc 0.63
Epoch: 6, loss: 0.77, acc: 0.76 val loss: 0.95 val acc 0.70
Epoch: 7, loss: 0.58, acc: 0.79 val loss: 1.12 val acc 0.69
Epoch: 8, loss: 0.43, acc: 0.84 val loss: 0.85 val acc 0.72
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment