An improved implementation of SGDR (Stochastic Gradient Descent with Warm Restarts) in Keras.
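The callback below implements the cosine-annealing schedule from Loshchilov & Hutter, "SGDR: Stochastic Gradient Descent with Warm Restarts" (arXiv:1608.03983). Within each restart interval of T_i epochs the learning rate follows

    \eta_t = \eta_{\min} + \frac{1}{2}\,(\eta_{\max} - \eta_{\min})\left(1 + \cos\!\left(\frac{T_{cur}}{T_i}\,\pi\right)\right)

with two tweaks over the paper: annealing only starts once validation accuracy reaches trigger_val_acc, and the maximum learning rate for each restart is compressed (divided by lr_max_compression, floored at the current rate) whenever validation accuracy improves.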
from keras.layers import Conv2D, Activation, BatchNormalization, Add, Input, GlobalAveragePooling2D, Dense
from keras.models import Model
from keras.optimizers import SGD
from keras.datasets import cifar10
from keras.utils import to_categorical
from keras.initializers import he_normal
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import Callback, LearningRateScheduler
from keras.regularizers import l2
import numpy as np
import pickle
class LearningRateCallback(Callback):
    def __init__(self, lr_max, lr_min, lr_max_compression=5, t0=10, tmult=1, trigger_val_acc=0.0, show_lr=True):
        super().__init__()
        # Global learning rate max/min
        self.lr_max = lr_max
        self.lr_min = lr_min
        # Max learning rate compression
        self.lr_max_compression = lr_max_compression
        # Warm restart params
        self.t0 = t0
        self.tmult = tmult
        # Learning rate decay trigger (decaying too early only slows training down)
        self.trigger_val_acc = trigger_val_acc
        # Init parameters
        self.show_lr = show_lr
        self._init_params()

    def _init_params(self):
        # Whether decay has been triggered
        self.triggered = False
        # Learning rate of the next warm restart
        self.lr_warmup_next = self.lr_max
        self.lr_warmup_current = self.lr_max
        # Current learning rate
        self.lr = self.lr_max
        # Current warm restart interval
        self.ti = self.t0
        # Epochs elapsed within the current interval
        self.tcur = 1
        # Best validation accuracy so far
        self.best_val_acc = 0

    def on_train_begin(self, logs=None):
        self._init_params()

    def on_epoch_end(self, epoch, logs=None):
        if not self.triggered and logs["val_acc"] >= self.trigger_val_acc:
            self.triggered = True
        if self.triggered:
            # Update the next warm-restart lr when validation accuracy improves
            if logs["val_acc"] > self.best_val_acc:
                self.best_val_acc = logs["val_acc"]
                # Avoid lr_warmup_next becoming too small
                if self.lr_max_compression > 0:
                    self.lr_warmup_next = max(self.lr_warmup_current / self.lr_max_compression, self.lr)
                else:
                    self.lr_warmup_next = self.lr
        if self.show_lr:
            print(f"epoch = {epoch+1}, sgdr_triggered = {self.triggered}, best_val_acc = {self.best_val_acc}, " +
                  f"current_lr = {self.lr:f}, next_warmup_lr = {self.lr_warmup_next:f}, next_warmup = {self.ti-self.tcur}")

    # SGDR schedule, passed to LearningRateScheduler
    def lr_scheduler(self, epoch):
        if not self.triggered: return self.lr
        # SGDR
        self.tcur += 1
        if self.tcur > self.ti:
            self.ti = int(self.tmult * self.ti)
            self.tcur = 1
            self.lr_warmup_current = self.lr_warmup_next
        self.lr = float(self.lr_min + (self.lr_warmup_current - self.lr_min) * (1 + np.cos(self.tcur/self.ti*np.pi)) / 2.0)
        return self.lr
# Something like a WideResNet (parameters differ slightly)
class ResNet:
    def __init__(self, n=4, initial_lr=1e-2, nb_epochs=100):
        self.n = n
        self.initial_lr = initial_lr
        self.nb_epochs = nb_epochs
        self.weight_decay = 0.0005
        # Build the model
        self.model = self.make_model()

    # Downsample with a stride-2 convolution instead of pooling
    def subsampling(self, output_channels, input_tensor):
        return Conv2D(output_channels, kernel_size=1, strides=(2,2), kernel_regularizer=l2(self.weight_decay))(input_tensor)

    # Shortcut around BN->ReLU->Conv->BN->ReLU->Conv (pre-act ResNet)
    def block(self, channels, input_tensor):
        # Shortcut branch
        shortcut = input_tensor
        # Main branch
        x = BatchNormalization()(input_tensor)
        x = Activation("relu")(x)
        x = Conv2D(channels, kernel_size=3, padding="same", kernel_regularizer=l2(self.weight_decay))(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
        x = Conv2D(channels, kernel_size=3, padding="same", kernel_regularizer=l2(self.weight_decay))(x)
        # Merge
        return Add()([x, shortcut])

    def make_model(self):
        input = Input(shape=(32, 32, 3))
        # Widen the channels from 3 to 80
        x = Conv2D(80, kernel_size=1, padding="same", kernel_regularizer=l2(self.weight_decay))(input)
        # n blocks of 32x32x80
        for i in range(self.n):
            x = self.block(80, x)
        # 16x16x160
        x = self.subsampling(160, x)
        for i in range(self.n):
            x = self.block(160, x)
        # 8x8x320
        x = self.subsampling(320, x)
        for i in range(self.n):
            x = self.block(320, x)
        # Global Average Pooling
        x = GlobalAveragePooling2D()(x)
        x = Dense(10, activation="softmax")(x)
        # Model
        model = Model(input, x)
        return model

    # Conventional step schedule (the usual breakpoints are 0.5 and 0.75 of total epochs)
    def state_of_art_lr_scheduler(self, epoch):
        x = self.initial_lr
        if epoch >= self.nb_epochs * 0.5: x /= 10.0
        if epoch >= self.nb_epochs * 0.75: x /= 10.0
        return x
    def train(self, X_train, y_train, X_val, y_val):
        # Compile
        self.model.compile(optimizer=SGD(lr=self.initial_lr, momentum=0.9), loss="categorical_crossentropy", metrics=["acc"])
        # Data augmentation
        traingen = ImageDataGenerator(
            width_shift_range=4./32,
            height_shift_range=4./32,
            horizontal_flip=True)
        # Learning rate callbacks
        lr_cbs = LearningRateCallback(self.initial_lr, self.initial_lr/100, 100, 120, 1, 0.85)
        sgdr = LearningRateScheduler(lr_cbs.lr_scheduler)
        #non_sgdr = LearningRateScheduler(self.state_of_art_lr_scheduler)
        # Train
        history = self.model.fit_generator(traingen.flow(X_train, y_train, batch_size=128), epochs=self.nb_epochs,
                                           steps_per_epoch=len(X_train)//128, validation_data=(X_val, y_val),
                                           callbacks=[lr_cbs, sgdr]).history
        # Save history
        with open("history_sgdr.dat", "wb") as fp:
            pickle.dump(history, fp)
# Main function
def main(nb_epochs):
    net = ResNet(nb_epochs=nb_epochs)
    # CIFAR-10
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    y_train, y_test = to_categorical(y_train), to_categorical(y_test)
    # Normalize to [0, 1]
    X_train = X_train / 255.0
    X_test = X_test / 255.0
    # Channel-wise normalization
    cifar_mean = np.mean(X_train, axis=(0,1,2)).reshape(1,1,1,3)
    cifar_sd = np.std(X_train, axis=(0,1,2)).reshape(1,1,1,3)
    X_train = (X_train - cifar_mean) / cifar_sd
    X_test = (X_test - cifar_mean) / cifar_sd
    # Train
    net.train(X_train, y_train, X_test, y_test)

if __name__ == "__main__":
    main(150)
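For reference, a minimal sketch (assuming matplotlib is available; the file name matches the pickle.dump above) of how the saved history can be inspected after training:

import pickle
import matplotlib.pyplot as plt

# Load the history dict pickled by ResNet.train()
with open("history_sgdr.dat", "rb") as fp:
    history = pickle.load(fp)

# Warm restarts typically show up as brief dips in accuracy
plt.plot(history["acc"], label="train_acc")
plt.plot(history["val_acc"], label="val_acc")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend()
plt.show()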