#!/usr/bin/env python3
import argparse
import os

# Quieten TensorFlow's C++ logging; this has to be set before TF is imported
# to have any effect on the import-time messages.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import tensorflow as tf
from tensorflow.keras.callbacks import *

import data as d
import model as m
from lr_finder import LearningRateFinder

# Enable XLA JIT compilation for the session.
tf.config.optimizer.set_jit(True)
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--train-tf-record-glob', type=str, required=True)
parser.add_argument('--num-batches', type=int, default=100)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--initial-learning-rate', type=float, default=1e-10)
parser.add_argument('--final-learning-rate', type=float, default=1e-1)
parser.add_argument('--shuffle-buffer-size', type=int, default=64)
parser.add_argument('--expit-squash', type=float, default=1.0)
parser.add_argument('--plus-one-weight', type=float, default=5.0)
parser.add_argument('--self-labelled-weight', type=float, default=1.0)
opts = parser.parse_args()
train_dataset = d.dataset_from_tfrecord(opts.train_tf_record_glob,
                                        batch_size=opts.batch_size,
                                        plus_one_weight=opts.plus_one_weight,
                                        self_labelled_weight=opts.self_labelled_weight,
                                        expit_squash=opts.expit_squash,
                                        training=True,
                                        shuffle_buffer=opts.shuffle_buffer_size)
# The learning rate passed at construction is irrelevant here; the finder
# overrides it before the first batch.
model = m.construct_model(learning_rate=1e-4)

finder = LearningRateFinder(model)
finder.find(train_dataset,
            initial_learning_rate=opts.initial_learning_rate,
            final_learning_rate=opts.final_learning_rate,
            num_batches=opts.num_batches)
finder.export_plot(fname="/tmp/foo.png")
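
# A hypothetical example invocation of this script (the file name and the
# TFRecord glob are placeholders, not part of the original gist):
#
#   python3 lr_find.py \
#       --train-tf-record-glob 'data/train-*.tfrecord' \
#       --num-batches 200 \
#       --batch-size 64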

# lr_finder.py

import math

from matplotlib import pyplot as plt
import tensorflow.keras.backend as K


class LearningRateFinder:
    def __init__(self, model):
        self.model = model
        self.losses = []
        self.learning_rates = []

    def find(self, train_dataset, num_batches, initial_learning_rate, final_learning_rate):
        # Sweep the learning rate geometrically from initial to final over
        # num_batches steps; the multiplier is the per-batch growth factor.
        learning_rate_ratio = final_learning_rate / initial_learning_rate
        self.learning_rate_multiplier = learning_rate_ratio ** (1 / num_batches)
        print("num_batches", num_batches)
        print("self.learning_rate_multiplier", self.learning_rate_multiplier)
        K.set_value(self.model.optimizer.learning_rate, initial_learning_rate)
        print("initial_learning_rate", initial_learning_rate)
        # TODO: we make the assumption that there is enough data to
        # take at _least_ num_batches
        for i, (imgs, labels, _weights) in enumerate(train_dataset.take(num_batches)):
            # TODO: no sample_weight on train_on_batch in this keras version?
            loss, _accuracy = self.model.train_on_batch(imgs, labels)
            if math.isnan(loss):
                # Loss has diverged; stop the sweep here.
                break
            self.losses.append(loss)
            learning_rate = K.get_value(self.model.optimizer.learning_rate)
            self.learning_rates.append(learning_rate)
            print("%d/%d learning_rate=%s loss=%s" % (i, num_batches, learning_rate, loss))
            learning_rate *= self.learning_rate_multiplier
            K.set_value(self.model.optimizer.learning_rate, learning_rate)

    def export_plot(self, fname):
        # Loss against learning rate on a log-x axis.
        plt.ylabel("loss")
        plt.xlabel("learning rate")
        plt.plot(self.learning_rates, self.losses)
        plt.xscale('log')
        plt.savefig(fname)
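
The gist stops at exporting the plot, so picking a learning rate is left to eyeballing the curve. A common follow-up is the "steepest descent" heuristic used by fastai's LR finder: choose the learning rate where the loss is falling fastest on the log-x curve. A minimal sketch of that idea, assuming numpy is available (the function name suggest_learning_rate and the skip_end parameter are mine, not matpalm's):

import numpy as np

def suggest_learning_rate(learning_rates, losses, skip_end=5):
    # Drop the last few points, which are usually dominated by divergence.
    lrs = np.array(learning_rates[:-skip_end] if skip_end else learning_rates)
    ls = np.array(losses[:-skip_end] if skip_end else losses)
    # Slope of loss with respect to log(lr); the most negative slope is where
    # the loss is dropping fastest.
    slopes = np.gradient(ls, np.log(lrs))
    return float(lrs[np.argmin(slopes)])

# e.g. suggested = suggest_learning_rate(finder.learning_rates, finder.losses)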