@koshian2
Created August 28, 2018 15:14
Improve implementation of SGDR: Stochastic Gradient Descent with Warm Restarts in Keras
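For reference, the schedule implemented by the callback below is the cosine annealing with warm restarts from Loshchilov & Hutter's SGDR paper: within a restart cycle of length t_i, the learning rate at step t_cur is (matching lr_scheduler below)

    lr = lr_min + (lr_warmup_current - lr_min) * (1 + cos(pi * t_cur / t_i)) / 2

where lr_warmup_current is the per-cycle maximum. This implementation additionally lowers the next cycle's maximum toward the current learning rate whenever validation accuracy improves (but no further than lr_warmup_current / lr_max_compression), and only starts decaying once validation accuracy exceeds trigger_val_acc.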
from keras.layers import Conv2D, Activation, BatchNormalization, Add, Input, GlobalAveragePooling2D, Dense
from keras.models import Model
from keras.optimizers import SGD
from keras.datasets import cifar10
from keras.utils import to_categorical
from keras.initializers import he_normal
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import Callback, LearningRateScheduler
from keras.regularizers import l2
import numpy as np
import pickle
class LearningRateCallback(Callback):
    def __init__(self, lr_max, lr_min, lr_max_compression=5, t0=10, tmult=1, trigger_val_acc=0.0, show_lr=True):
        super().__init__()
        # Global learning rate max/min
        self.lr_max = lr_max
        self.lr_min = lr_min
        # Max learning rate compression
        self.lr_max_compression = lr_max_compression
        # Warm restarts params
        self.t0 = t0
        self.tmult = tmult
        # Learning rate decay trigger (decaying too early only slows training down)
        self.trigger_val_acc = trigger_val_acc
        # init parameters
        self.show_lr = show_lr
        self._init_params()
    def _init_params(self):
        # Decay triggered
        self.triggered = False
        # Learning rate of next warm up
        self.lr_warmup_next = self.lr_max
        self.lr_warmup_current = self.lr_max
        # Current learning rate
        self.lr = self.lr_max
        # Current warm restart interval
        self.ti = self.t0
        # Warm restart count
        self.tcur = 1
        # Best validation accuracy
        self.best_val_acc = 0

    def on_train_begin(self, logs):
        self._init_params()
    def on_epoch_end(self, epoch, logs):
        if not self.triggered and logs["val_acc"] >= self.trigger_val_acc:
            self.triggered = True

        if self.triggered:
            # Update next warmup lr when validation acc surpassed
            if logs["val_acc"] > self.best_val_acc:
                self.best_val_acc = logs["val_acc"]
                # Avoid lr_warmup_next too small
                if self.lr_max_compression > 0:
                    self.lr_warmup_next = max(self.lr_warmup_current / self.lr_max_compression, self.lr)
                else:
                    self.lr_warmup_next = self.lr

        if self.show_lr:
            print(f"epoch = {epoch+1}, sgdr_triggered = {self.triggered}, best_val_acc = {self.best_val_acc}, " +
                  f"current_lr = {self.lr:f}, next_warmup_lr = {self.lr_warmup_next:f}, next_warmup = {self.ti-self.tcur}")
    # SGDR
    def lr_scheduler(self, epoch):
        if not self.triggered:
            return self.lr
        # SGDR
        self.tcur += 1
        if self.tcur > self.ti:
            self.ti = int(self.tmult * self.ti)
            self.tcur = 1
            self.lr_warmup_current = self.lr_warmup_next
        self.lr = float(self.lr_min + (self.lr_warmup_current - self.lr_min) * (1 + np.cos(self.tcur / self.ti * np.pi)) / 2.0)
        return self.lr
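# A minimal usage sketch for LearningRateCallback above (names such as `model` and `X_train`
# are placeholders, not defined in this file): the instance is passed twice -- once so its
# on_epoch_end hook can track validation accuracy, and once wrapped in LearningRateScheduler
# so lr_scheduler supplies the cosine-annealed rate each epoch.
#   cb = LearningRateCallback(lr_max=0.01, lr_min=0.0001, t0=10, tmult=2, trigger_val_acc=0.8)
#   model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=100,
#             callbacks=[cb, LearningRateScheduler(cb.lr_scheduler)])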
# Something WideResNet-ish (the parameters are slightly different)
class ResNet:
    def __init__(self, n=4, initial_lr=1e-2, nb_epochs=100):
        self.n = n
        self.initial_lr = initial_lr
        self.nb_epochs = nb_epochs
        self.weight_decay = 0.0005
        # Make model
        self.model = self.make_model()

    # Use a stride-2 Conv instead of pooling
    def subsampling(self, output_channels, input_tensor):
        return Conv2D(output_channels, kernel_size=1, strides=(2, 2), kernel_regularizer=l2(self.weight_decay))(input_tensor)

    # Shortcut around BN->ReLU->Conv->BN->ReLU->Conv (pre-activation ResNet)
    def block(self, channels, input_tensor):
        # Shortcut branch
        shortcut = input_tensor
        # Main branch
        x = BatchNormalization()(input_tensor)
        x = Activation("relu")(x)
        x = Conv2D(channels, kernel_size=3, padding="same", kernel_regularizer=l2(self.weight_decay))(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
        x = Conv2D(channels, kernel_size=3, padding="same", kernel_regularizer=l2(self.weight_decay))(x)
        # Merge
        return Add()([x, shortcut])

    def make_model(self):
        input = Input(shape=(32, 32, 3))
        # Increase the channel count from 3 to 80
        x = Conv2D(80, kernel_size=1, padding="same", kernel_regularizer=l2(self.weight_decay))(input)
        # n blocks at 32x32x80
        for i in range(self.n):
            x = self.block(80, x)
        # 16x16x160
        x = self.subsampling(160, x)
        for i in range(self.n):
            x = self.block(160, x)
        # 8x8x320
        x = self.subsampling(320, x)
        for i in range(self.n):
            x = self.block(320, x)
        # Global Average Pooling
        x = GlobalAveragePooling2D()(x)
        x = Dense(10, activation="softmax")(x)
        # Model
        model = Model(input, x)
        return model
    # Common step-decay schedule (the usual breakpoints are 0.5 and 0.75 of the total epochs)
    def state_of_art_lr_scheduler(self, epoch):
        x = self.initial_lr
        if epoch >= self.nb_epochs * 0.5:
            x /= 10.0
        if epoch >= self.nb_epochs * 0.75:
            x /= 10.0
        return x
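    # For reference: with the defaults used in main() (initial_lr=1e-2, nb_epochs=150), this
    # commented-out alternative would give lr=0.01 for epochs 0-74, 0.001 for epochs 75-112,
    # and 0.0001 from epoch 113 onward.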
    def train(self, X_train, y_train, X_val, y_val):
        # Compile
        self.model.compile(optimizer=SGD(lr=self.initial_lr, momentum=0.9), loss="categorical_crossentropy", metrics=["acc"])
        # Data Augmentation
        traingen = ImageDataGenerator(
            width_shift_range=4./32,
            height_shift_range=4./32,
            horizontal_flip=True)
        # Learning rate callback
        lr_cbs = LearningRateCallback(self.initial_lr, self.initial_lr / 100, 100, 120, 1, 0.85)
        sgdr = LearningRateScheduler(lr_cbs.lr_scheduler)
        #non_sgdr = LearningRateScheduler(self.state_of_art_lr_scheduler)
        # Train
        history = self.model.fit_generator(traingen.flow(X_train, y_train, batch_size=128), epochs=self.nb_epochs,
                                           steps_per_epoch=len(X_train) // 128, validation_data=(X_val, y_val),
                                           callbacks=[lr_cbs, sgdr]).history
        # Save history
        with open("history_sgdr.dat", "wb") as fp:
            pickle.dump(history, fp)
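        # The pickled history can be inspected later, e.g. (a sketch; the key names follow the
        # metrics compiled above):
        #   with open("history_sgdr.dat", "rb") as fp:
        #       history = pickle.load(fp)
        #   print(max(history["val_acc"]))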
# Main function
def main(nb_epochs):
    net = ResNet(nb_epochs=nb_epochs)
    # CIFAR-10
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    y_train, y_test = to_categorical(y_train), to_categorical(y_test)
    # Scale to [0, 1]
    X_train = X_train / 255.0
    X_test = X_test / 255.0
    # Channel-wise normalization
    cifar_mean = np.mean(X_train, axis=(0, 1, 2)).reshape(1, 1, 1, 3)
    cifar_sd = np.std(X_train, axis=(0, 1, 2)).reshape(1, 1, 1, 3)
    X_train = (X_train - cifar_mean) / cifar_sd
    X_test = (X_test - cifar_mean) / cifar_sd
    # Train
    net.train(X_train, y_train, X_test, y_test)

if __name__ == "__main__":
    main(150)