Last active
December 18, 2023 14:02
-
-
Save mwitiderrick/73ec6285b6b7ab09d2b0cb00c26d2782 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax

# Run configuration.
# NOTE: `momentum`, `create_train_state` and `validation_loader` are not
# defined in this snippet; they must already be in scope.
seed = 0
learning_rate = 1e-5
# NOTE(review): the train_model(...) call in this file passes num_epochs=20
# explicitly, so this value is not what that call uses -- confirm which one
# is intended.
num_epochs = 30

# Derive a dedicated initialisation key from the seeded root key. JAX keys
# should be consumed only once, so the root key is split rather than being
# handed to the initialiser directly.
rng = jax.random.PRNGKey(seed)
rng, init_rng = jax.random.split(rng)

# Build the train state exactly once, from the dedicated init key. The
# previous version built it a second time from PRNGKey(seed), which threw the
# first state away and reused the root key that had already been split.
state = create_train_state(init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.

# Take the first batch yielded by the validation loader.
(test_images, test_labels) = next(iter(validation_loader))
# Presumably rescales 8-bit pixel values into [0, 1] -- confirm against the
# loader's output dtype/range.
test_images = test_images / 255.0

# Per-epoch metric histories; train_model() appends one value to each list
# at the end of every epoch.
training_loss = []
training_accuracy = []
testing_loss = []
testing_accuracy = []
def train_model(state, train_loader, test_loader, num_epochs=30):
    """Run the train/validate loop and return the final train state.

    For each of ``num_epochs`` epochs, ``train_step`` is applied to every
    batch of ``train_loader`` (the returned state replaces the current one),
    then ``eval_step`` is applied to every batch of ``test_loader`` using the
    state as it stands after the training pass.

    The mean loss and accuracy of each pass are appended to the module-level
    lists ``training_loss``, ``training_accuracy``, ``testing_loss`` and
    ``testing_accuracy``, and a one-line summary is printed per epoch.

    Args:
        state: Initial train state, passed to ``train_step`` / ``eval_step``.
        train_loader: Iterable of training batches.
        test_loader: Iterable of validation batches.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        The state produced by the last ``train_step`` call.
    """
    for epoch in tqdm(range(num_epochs)):
        # Training pass: each step consumes the state produced by the
        # previous one, so this has to stay a sequential loop.
        train_losses, train_accs = [], []
        for batch in train_loader:
            state, step_metrics = train_step(state, batch)
            train_losses.append(step_metrics['loss'])
            train_accs.append(step_metrics['accuracy'])

        # Validation pass: the state is only read here, never updated.
        val_losses, val_accs = [], []
        for batch in test_loader:
            step_metrics = eval_step(state, batch)
            val_losses.append(step_metrics['loss'])
            val_accs.append(step_metrics['accuracy'])

        # Collapse the per-batch values into one number per epoch.
        epoch_train_loss = np.mean(train_losses)
        epoch_train_acc = np.mean(train_accs)
        epoch_val_loss = np.mean(val_losses)
        epoch_val_acc = np.mean(val_accs)

        # Record the epoch summary in the module-level histories.
        for history, value in (
            (training_loss, epoch_train_loss),
            (training_accuracy, epoch_train_acc),
            (testing_loss, epoch_val_loss),
            (testing_accuracy, epoch_val_acc),
        ):
            history.append(value)

        print(
            f"Epoch: {epoch + 1}, loss: {epoch_train_loss:.2f}, acc: {epoch_train_acc:.2f} val loss: {epoch_val_loss:.2f} val acc {epoch_val_acc:.2f} "
        )

    return state
# Train for 20 epochs; train_model() evaluates on `validation_loader` after
# every epoch and returns the final train state.
trained_model_state = train_model(
    state=state,
    train_loader=train_loader,
    test_loader=validation_loader,
    num_epochs=20,
)

# The string below is a no-op expression; it appears to be console output
# pasted from an earlier run (only the first 8 epochs are shown).
"""
Epoch: 1, loss: 6.41, acc: 0.57 val loss: 2.28 val acc 0.59
Epoch: 2, loss: 1.72, acc: 0.64 val loss: 1.35 val acc 0.66
Epoch: 3, loss: 1.29, acc: 0.68 val loss: 1.43 val acc 0.63
Epoch: 4, loss: 0.98, acc: 0.72 val loss: 1.14 val acc 0.67
Epoch: 5, loss: 0.83, acc: 0.74 val loss: 1.73 val acc 0.63
Epoch: 6, loss: 0.77, acc: 0.76 val loss: 0.95 val acc 0.70
Epoch: 7, loss: 0.58, acc: 0.79 val loss: 1.12 val acc 0.69
Epoch: 8, loss: 0.43, acc: 0.84 val loss: 0.85 val acc 0.72
"""
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment