Created
September 8, 2019 08:53
-
-
Save denised/82cd41ab8a4c3d1d9abb12ea5e481f36 to your computer and use it in GitHub Desktop.
Accumulate multiple lr/loss curves in fastai
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from matplotlib import pyplot | |
from fastai.core import ifnone | |
# See my post on forums.fast.ai for an example | |
class LRAccumulator(object):
    """Accumulate multiple recorder results to compare them on the same graph.

    Can be applied across any Learner fit method (lr_find, fit, etc.), and a single
    accumulator can be used across multiple learners, models, data... anything where
    you'd like to compare the loss graphs.

    Each entry of ``self.curves`` is a ``(title, lrs, losses, fmt)`` tuple taken from
    ``learner.recorder`` at the time ``add`` was called.
    """

    def __init__(self, learner=None, title="a", fmt=''):
        """Create a new accumulator, optionally starting with an initial recorder trace.

        If ``learner`` is given, its current recorder trace is added immediately
        with the given ``title`` and ``fmt`` (see ``add``).
        """
        self.curves = []
        # Explicit None check: the learner object's truthiness is irrelevant here.
        if learner is not None:
            self.add(learner, title, fmt)

    @staticmethod
    def _alpha_label(index):
        """Map 0, 1, ..., 25, 26, 27, ... to 'a', 'b', ..., 'z', 'aa', 'ab', ... (spreadsheet style)."""
        label = ""
        index += 1
        while index > 0:
            index, rem = divmod(index - 1, 26)
            label = chr(ord("a") + rem) + label
        return label

    def _default_title(self):
        """Return the first auto-generated label not already used by an accumulated curve."""
        used = {curve[0] for curve in self.curves}
        index = 0
        while self._alpha_label(index) in used:
            index += 1
        return self._alpha_label(index)

    def add(self, learner, title=None, fmt=''):
        """Add another recorder trace to the list.

        ``title`` is the legend label. When omitted, the first unused label from
        'a', 'b', ..., 'z', 'aa', ... is chosen, so labels stay unique even after
        ``drop``. The format of the curve can be specified with the ``fmt`` argument
        using the matplotlib format shortcut notation (e.g. 'ro-').
        """
        if title is None:
            title = self._default_title()
        recorder = learner.recorder
        # Snapshot both sequences so later changes to the recorder's lists cannot
        # alter a curve that has already been accumulated. float() accepts
        # single-element tensors (as .item() did) as well as plain numbers.
        lrs = list(recorder.lrs)
        losses = [float(x) for x in recorder.losses]
        self.curves.append((title, lrs, losses, fmt))

    def drop(self, index=-1):
        """Remove the curve at ``index`` (default: the most recently added one).

        Useful when the wrong curve was added by mistake. Raises IndexError if
        there is no curve at ``index``.
        """
        del self.curves[index]

    def plot(self, bylrs=True, xmin=None, xmax=None, ymin=None, ymax=None):
        """Plot all the accumulated curves on a single new figure.

        By default, plots loss against learning rate on a log x axis, which is
        appropriate for comparing lr_find results. To compare other loss traces,
        set ``bylrs=False``: each loss list is then plotted against its index
        (labelled "Batch"). By default the graph is scaled to include all the data
        for all the curves; use the xmin/xmax and ymin/ymax arguments to focus on
        the interesting part (any of them may be left as None).
        """
        from matplotlib.ticker import FormatStrFormatter

        _, ax = pyplot.subplots(1, 1)
        for label, xs, ys, fmt in self.curves:
            if bylrs:
                ax.plot(xs, ys, fmt, label=label)
            else:
                ax.plot(ys, fmt, label=label)
        ax.set_ylabel("Loss")
        ax.set_xlabel("Learning Rate" if bylrs else "Batch")
        # Switch to log scale *before* applying user limits. Limits applied while the
        # axis is still linear keep the linear autoscale margin for the unspecified
        # side, which can be non-positive and therefore invalid on a log axis.
        if bylrs:
            ax.set_xscale('log')
            ax.xaxis.set_major_formatter(FormatStrFormatter('%.0e'))
        if xmin is not None or xmax is not None:
            ax.set_xlim(left=xmin, right=xmax)
        if ymin is not None or ymax is not None:
            ax.set_ylim(bottom=ymin, top=ymax)
        # legend() warns when there are no labelled artists, so skip it when empty.
        if self.curves:
            ax.legend()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment