Skip to content

Instantly share code, notes, and snippets.

@devforfu
Created December 7, 2018 09:27
Show Gist options
Save devforfu/252b8b0e92c6107e31ded7a01cdf9c34 to your computer and use it in GitHub Desktop.
Smoothed loss callback
class RollingLoss(Callback):
    """Report an exponentially smoothed, bias-corrected batch loss.

    After every batch, the raw ``phase.batch_loss`` is folded into an
    exponential moving average kept on ``phase.rolling_loss``. The
    bias-corrected value of that average is passed to ``phase.update``.
    At the end of each epoch, ``phase.last_loss`` is recorded as the phase's
    ``'loss'`` metric.

    Args:
        smooth: Decay factor of the moving average, in ``[0, 1)``. Larger
            values give a smoother but slower-reacting curve; ``0`` reports
            the raw batch loss unchanged.

    Raises:
        ValueError: If ``smooth`` is outside ``[0, 1)``.
    """

    def __init__(self, smooth=0.98):
        # smooth == 1 would make the bias-correction denominator
        # ``1 - smooth ** k`` zero on every batch, and values outside [0, 1)
        # do not describe a moving average at all, so reject them up front
        # instead of failing later with a ZeroDivisionError or nonsense values.
        if not 0 <= smooth < 1:
            raise ValueError(f"smooth must be in [0, 1), got {smooth!r}")
        self.smooth = smooth

    def batch_ended(self, phase, **kwargs):
        """Fold the latest batch loss into the phase's running average.

        Reads ``phase.rolling_loss`` (the previous uncorrected average),
        ``phase.batch_loss`` and ``phase.batch_index``. Writes the new
        uncorrected average back to ``phase.rolling_loss``, then passes the
        bias-corrected value to ``phase.update``.

        Raises:
            ValueError: If ``phase.batch_index`` is less than 1.
        """
        step = phase.batch_index
        # The correction divides by ``1 - smooth ** step``, which is zero at
        # step 0, so the index must be 1-based (incremented before this hook
        # runs for the first batch).
        if step < 1:
            raise ValueError(
                f"batch_index must be >= 1 for bias correction, got {step!r}")
        a = self.smooth
        avg_loss = a * phase.rolling_loss + (1 - a) * phase.batch_loss
        # Dividing by ``1 - a ** step`` removes the bias of an average that
        # started from 0. NOTE(review): this assumes phase.rolling_loss is
        # initialised to 0 before the first batch; confirm in Phase.
        debias_loss = avg_loss / (1 - a ** step)
        phase.rolling_loss = avg_loss
        phase.update(debias_loss)

    def epoch_ended(self, phases, **kwargs):
        """Record each phase's ``last_loss`` as its ``'loss'`` metric."""
        # presumably last_loss is the most recent value passed to
        # phase.update by batch_ended; verify against Phase.
        for phase in phases:
            phase.update_metric('loss', phase.last_loss)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.