Skip to content

Instantly share code, notes, and snippets.

@denised
Created September 8, 2019 08:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save denised/82cd41ab8a4c3d1d9abb12ea5e481f36 to your computer and use it in GitHub Desktop.
Save denised/82cd41ab8a4c3d1d9abb12ea5e481f36 to your computer and use it in GitHub Desktop.
Accumulate multiple lr/loss curves in fastai
from matplotlib import pyplot
from fastai.core import ifnone
# See my post on forums.fast.ai for an example
class LRAccumulator(object):
    """Accumulate multiple recorder results to compare them on the same graph.

    Can be applied across any Learner fit method (lr_find, fit, etc.), and a
    single accumulator can be used across multiple learners, models, data...
    anything where you'd like to compare the loss graphs.

    Each entry of ``self.curves`` is a ``(title, lrs, losses, fmt)`` tuple, in
    the order the traces were added.
    """

    def __init__(self, learner=None, title="a", fmt=''):
        """Create a new accumulator, optionally starting with an initial recorder trace.

        Args:
            learner: object exposing ``recorder.lrs`` and ``recorder.losses``.
                When None, the accumulator starts empty.
            title: legend label for the initial trace.
            fmt: matplotlib format shortcut for the initial trace (e.g. 'ro-').
        """
        self.curves = []
        # Compare against None explicitly so that a learner object which
        # happens to evaluate as falsy is still recorded.
        if learner is not None:
            self.add(learner, title, fmt)

    @staticmethod
    def _as_number(value):
        """Return a recorded loss as a plain Python number.

        Values exposing a callable ``item`` attribute are unwrapped through
        it; anything else is passed to ``float``.
        """
        item = getattr(value, "item", None)
        return item() if callable(item) else float(value)

    def add(self, learner, title=None, fmt=''):
        """Add another recorder trace to the list.

        Args:
            learner: object exposing ``recorder.lrs`` and ``recorder.losses``.
            title: legend label for the curve. When None, a letter is chosen
                from the number of curves currently stored ('a' for the
                first, 'b' for the second, ...).
            fmt: format of the curve, in matplotlib format shortcut notation
                (e.g. 'ro-').
        """
        if title is None:
            title = chr(ord("a") + len(self.curves))
        recorder = learner.recorder
        # Store a shallow copy of the learning rates (the losses are rebuilt
        # into a new list anyway) so that later in-place changes to the
        # recorder's list cannot alter a curve that was already accumulated.
        lrs = list(recorder.lrs)
        losses = [self._as_number(x) for x in recorder.losses]
        self.curves.append((title, lrs, losses, fmt))

    def drop(self, index=-1):
        """Remove a stored curve, e.g. after adding the wrong one by mistake.

        Args:
            index: position of the curve in ``self.curves``; the default of -1
                removes the most recently added curve.

        Raises:
            IndexError: if there is no curve at ``index``.
        """
        del self.curves[index]

    def plot(self, bylrs=True, xmin=None, xmax=None, ymin=None, ymax=None):
        """Plot all the accumulated curves on a single new set of axes.

        By default, plots loss against learning rate (which is appropriate for
        comparing lr_find results) on a logarithmic x axis. To compare other
        loss traces, set ``bylrs=False``; the losses are then plotted against
        their position in the trace. By default the graph will be scaled to
        include all the data for all the curves; use the xmin/xmax and
        ymin/ymax arguments to focus on the interesting part.

        Args:
            bylrs: plot loss against learning rate when True, against batch
                index when False.
            xmin, xmax: optional x-axis limits; the x limits are only touched
                when at least one of the two is given.
            ymin, ymax: optional y-axis limits; the y limits are only touched
                when at least one of the two is given.
        """
        # Imported from the module that documents it rather than relying on
        # pyplot re-exporting the name.
        from matplotlib.ticker import FormatStrFormatter

        _, ax = pyplot.subplots(1, 1)
        for label, xs, ys, fmt in self.curves:
            if bylrs:
                ax.plot(xs, ys, fmt, label=label)
            else:
                ax.plot(ys, fmt, label=label)
        ax.set_ylabel("Loss")
        ax.set_xlabel("Learning Rate" if bylrs else "Batch")
        if xmin is not None or xmax is not None:
            ax.set_xlim(left=xmin, right=xmax)
        if ymin is not None or ymax is not None:
            ax.set_ylim(bottom=ymin, top=ymax)
        if bylrs:
            ax.set_xscale('log')
            ax.xaxis.set_major_formatter(FormatStrFormatter('%.0e'))
        ax.legend()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment