Skip to content

Instantly share code, notes, and snippets.

@jeremyjordan
Last active December 4, 2023 13:41
Show Gist options
  • Save jeremyjordan/5a222e04bb78c242f5763ad40626c452 to your computer and use it in GitHub Desktop.
Save jeremyjordan/5a222e04bb78c242f5763ad40626c452 to your computer and use it in GitHub Desktop.
Keras Callback for implementing Stochastic Gradient Descent with Restarts
from keras.callbacks import Callback
import keras.backend as K
import numpy as np
class SGDRScheduler(Callback):
    '''Cosine annealing learning rate scheduler with periodic restarts.

    Within a cycle the learning rate follows half a cosine wave from
    `max_lr` down to `min_lr`. At the end of each cycle the rate jumps back
    up to (a possibly decayed) `max_lr`, the cycle length is scaled by
    `mult_factor`, and a snapshot of the model weights is taken.

    # Usage
        ```python
            schedule = SGDRScheduler(min_lr=1e-5,
                                     max_lr=1e-2,
                                     steps_per_epoch=np.ceil(epoch_size/batch_size),
                                     lr_decay=0.9,
                                     cycle_length=5,
                                     mult_factor=1.5)
            model.fit(X_train, Y_train, epochs=100, callbacks=[schedule])
        ```

    # Arguments
        min_lr: The lower bound of the learning rate range for the experiment.
        max_lr: The upper bound of the learning rate range for the experiment.
        steps_per_epoch: Number of mini-batches in the dataset.
            Calculated as `np.ceil(epoch_size/batch_size)`.
        lr_decay: Reduce the max_lr after the completion of each cycle.
            Ex. To reduce the max_lr by 20% after each cycle, set this
            value to 0.8.
        cycle_length: Initial number of epochs in a cycle.
        mult_factor: Scale cycle_length after each full cycle completion.

    # Notes
        Cycles are counted in epochs seen by this callback instance, not in
        the absolute epoch index passed by Keras. Resuming training with a
        non-zero initial epoch therefore starts a fresh cycle instead of
        never reaching the restart condition.

    # References
        Blog post: jeremyjordan.me/nn-learning-rate
        Original paper: http://arxiv.org/abs/1608.03983
    '''

    def __init__(self,
                 min_lr,
                 max_lr,
                 steps_per_epoch,
                 lr_decay=1,
                 cycle_length=10,
                 mult_factor=2):
        # Let the base Callback set up its own attributes first.
        super(SGDRScheduler, self).__init__()

        self.min_lr = min_lr
        self.max_lr = max_lr
        self.lr_decay = lr_decay

        self.batch_since_restart = 0
        # Epochs completed while this callback was attached. Restarts are
        # keyed on this rather than on the absolute epoch index so that the
        # epoch bookkeeping agrees with `batch_since_restart`.
        self.epochs_completed = 0
        self.next_restart = cycle_length

        self.steps_per_epoch = steps_per_epoch

        self.cycle_length = cycle_length
        self.mult_factor = mult_factor

        # Stays None until the first cycle completes.
        self.best_weights = None
        self.history = {}

    def clr(self):
        '''Calculate the learning rate for the current position in the cycle.'''
        fraction_to_restart = self.batch_since_restart / (self.steps_per_epoch * self.cycle_length)
        lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + np.cos(fraction_to_restart * np.pi))
        return lr

    def on_train_begin(self, logs=None):
        '''Initialize the learning rate to the maximum value at the start of training.'''
        K.set_value(self.model.optimizer.lr, self.max_lr)

    def on_batch_end(self, batch, logs=None):
        '''Record previous batch statistics and update the learning rate.'''
        logs = logs or {}
        self.history.setdefault('lr', []).append(K.get_value(self.model.optimizer.lr))
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

        self.batch_since_restart += 1
        K.set_value(self.model.optimizer.lr, self.clr())

    def on_epoch_end(self, epoch, logs=None):
        '''Check for end of current cycle, apply restarts when necessary.'''
        self.epochs_completed += 1
        # `>=` rather than `==`: cycle lengths may be non-integral, and an
        # exact comparison would silently skip such a restart forever.
        if self.epochs_completed >= self.next_restart:
            self.batch_since_restart = 0
            self.cycle_length = np.ceil(self.cycle_length * self.mult_factor)
            self.next_restart += self.cycle_length
            self.max_lr *= self.lr_decay
            self.best_weights = self.model.get_weights()

    def on_train_end(self, logs=None):
        '''Set weights to the values from the end of the most recent cycle for best performance.

        If training stopped before any cycle completed there is no snapshot
        to restore, so the model keeps its current weights.
        '''
        if self.best_weights is not None:
            self.model.set_weights(self.best_weights)
@Trotts
Copy link

Trotts commented Jun 26, 2020

@jeremyjordan Sorry for coming back to this after so long, but I recently noticed something in the code that I was wondering if you could explain.

Let's say I have a model that I previously trained for 89 epochs. I then resume training at a later date and wish to train for another 11 epochs, up to 100. In this case, it seems as though you will never hit a cycle restart unless you specify a cycle length > 89, due to this line:

if epoch + 1 == self.next_restart:

However, with such a high cycle length, you may never again hit another restart.

For example:

number_epochs = 100

# Create the log directory if needed; exist_ok avoids the check-then-create race.
os.makedirs(model.log_dir, exist_ok=True)

csv_logger = CSVLogger(os.path.join(model.log_dir, "epoch_logger.csv"), append=True)

# Save the model config to the log directory.
# Use a context manager so the file handle is closed (and flushed) promptly.
with open(os.path.join(model.log_dir, "config.txt"), 'w') as config_file:
    print(config_list, file=config_file)

# Train the head branches
# Passing layers="heads" freezes all layers except the head
# layers. You can also pass a regular expression to select
# which layers to train by name pattern.

# Cycle length is what you want + epoch restart num
# (e.g if you want a cycle length of 2, restart epoch is 89, thus cycle length = 91)
# NOTE(review): steps_per_epoch is derived from number_epochs here, whereas the
# scheduler documents it as ceil(epoch_size / batch_size) -- confirm this is intended.
schedule = SGDRScheduler(min_lr=1e-5,
                         max_lr=1e-2,
                         steps_per_epoch=np.ceil(number_epochs/config.BATCH_SIZE),
                         lr_decay=0.9,
                         cycle_length=91,
                         mult_factor=1.5)

model.train(dataset_train, dataset_val,
            learning_rate=1e-2,
            epochs=number_epochs,
            layers='heads', best_only=True,
            custom_callbacks=[schedule, csv_logger],
            augmentation=aug1)

Gives:

Starting at epoch 89. LR=0.01

Checkpoint Path: /home/b3020111/dolphin-recognition/Mask_RCNN-master/logs/ndd/above/od/ndd-above-od-1.10-tester20200324T1218/mask_rcnn_ndd-above-od-1.10-tester_{epoch:04d}.h5
Selecting layers to train
fpn_c5p5               (Conv2D)
fpn_c4p4               (Conv2D)
fpn_c3p3               (Conv2D)
fpn_c2p2               (Conv2D)
fpn_p5                 (Conv2D)
fpn_p2                 (Conv2D)
fpn_p3                 (Conv2D)
fpn_p4                 (Conv2D)
In model:  rpn_model
    rpn_conv_shared        (Conv2D)
    rpn_class_raw          (Conv2D)
    rpn_bbox_pred          (Conv2D)
mrcnn_mask_conv1       (TimeDistributed)
mrcnn_mask_bn1         (TimeDistributed)
mrcnn_mask_conv2       (TimeDistributed)
mrcnn_mask_bn2         (TimeDistributed)
mrcnn_class_conv1      (TimeDistributed)
mrcnn_class_bn1        (TimeDistributed)
mrcnn_mask_conv3       (TimeDistributed)
mrcnn_mask_bn3         (TimeDistributed)
mrcnn_class_conv2      (TimeDistributed)
mrcnn_class_bn2        (TimeDistributed)
mrcnn_mask_conv4       (TimeDistributed)
mrcnn_mask_bn4         (TimeDistributed)
mrcnn_bbox_fc          (TimeDistributed)
mrcnn_mask_deconv      (TimeDistributed)
mrcnn_class_logits     (TimeDistributed)
mrcnn_mask             (TimeDistributed)
Epoch 90/100
 99/100 [============================>.] - ETA: 3s - loss: 0.9831 - rpn_class_loss: 0.0082 - rpn_bbox_loss: 0.6570 - mrcnn_class_loss: 0.0171 - mrcnn_bbox_loss: 0.1787 - mrcnn_mask_loss: 0.1214
 epoch end, checking to see if end of cycle...

 epoch + 1 =  90  self.next_restart =  91
100/100 [==============================] - 378s 4s/step - loss: 0.9816 - rpn_class_loss: 0.0083 - rpn_bbox_loss: 0.6560 - mrcnn_class_loss: 0.0171 - mrcnn_bbox_loss: 0.1783 - mrcnn_mask_loss: 0.1213 - val_loss: 1.7121 - val_rpn_class_loss: 0.0104 - val_rpn_bbox_loss: 1.4068 - val_mrcnn_class_loss: 0.0109 - val_mrcnn_bbox_loss: 0.1635 - val_mrcnn_mask_loss: 0.1200
Epoch 91/100
 99/100 [============================>.] - ETA: 2s - loss: 1.0819 - rpn_class_loss: 0.0092 - rpn_bbox_loss: 0.7405 - mrcnn_class_loss: 0.0164 - mrcnn_bbox_loss: 0.1952 - mrcnn_mask_loss: 0.1199
 epoch end, checking to see if end of cycle...

 epoch + 1 =  91  self.next_restart =  91

 cycle finished, saving weights...
Next restart at  228.0
100/100 [==============================] - 300s 3s/step - loss: 1.0881 - rpn_class_loss: 0.0092 - rpn_bbox_loss: 0.7459 - mrcnn_class_loss: 0.0164 - mrcnn_bbox_loss: 0.1957 - mrcnn_mask_loss: 0.1202 - val_loss: 1.2089 - val_rpn_class_loss: 0.0113 - val_rpn_bbox_loss: 0.8753 - val_mrcnn_class_loss: 0.0101 - val_mrcnn_bbox_loss: 0.1912 - val_mrcnn_mask_loss: 0.1204
Epoch 92/100
 99/100 [============================>.] - ETA: 2s - loss: 1.0154 - rpn_class_loss: 0.0100 - rpn_bbox_loss: 0.6705 - mrcnn_class_loss: 0.0160 - mrcnn_bbox_loss: 0.1912 - mrcnn_mask_loss: 0.1269
 epoch end, checking to see if end of cycle...

 epoch + 1 =  92  self.next_restart =  228.0

If the cycle length is changed from 91 to, say, 2, then the `epoch + 1` check is never satisfied, so `best_weights` is never stored, causing the code to raise an error at the end of training.

Am I correct in this interpretation of how the cycle length is working in this code, and if so, is there a way to allow for smaller cycle lengths while still allowing for the restarting of training?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment