@hiromis
Created March 31, 2019 13:54
git/fastai_docs/dev_course/dl2/tmp_TEST2_04_callbacks.ipynb
{
"cells": [
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%load_ext autoreload\n%autoreload 2\n\n%matplotlib inline",
"execution_count": 1,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "#export\nfrom exp.nb_03 import *",
"execution_count": 2,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## DataBunch/Learner"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "x_train,y_train,x_valid,y_valid = get_data()\ntrain_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)\nnh,bs = 50,64\nc = y_train.max().item()+1\nloss_func = F.cross_entropy",
"execution_count": 3,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Factor the connected pieces of information out of the `fit()` argument list:\n\n`fit(epochs, model, loss_func, opt, train_dl, valid_dl)`\n\nLet's replace it with something that looks like this:\n\n`fit(1, learn)`\n\nThis lets us tweak what happens inside the training loop from other places in the code: because the `Learner` object is mutable, a change to any of its attributes made elsewhere will be seen by our training loop."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "#export\nclass DataBunch():\n    def __init__(self, train_dl, valid_dl, c=None):\n        self.train_dl,self.valid_dl,self.c = train_dl,valid_dl,c\n\n    @property\n    def train_ds(self): return self.train_dl.dataset\n\n    @property\n    def valid_ds(self): return self.valid_dl.dataset",
"execution_count": 4,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)",
"execution_count": 5,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "#export\ndef get_model(data, lr=0.5, nh=50):\n    m = data.train_ds.x.shape[1]\n    model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,data.c))\n    return model, optim.SGD(model.parameters(), lr=lr)\n\nclass Learner():\n    def __init__(self, model, opt, loss_func, data):\n        self.model,self.opt,self.loss_func,self.data = model,opt,loss_func,data",
"execution_count": 6,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn = Learner(*get_model(data), loss_func, data)",
"execution_count": 7,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def fit(epochs, learn):\n    for epoch in range(epochs):\n        learn.model.train()\n        for xb,yb in learn.data.train_dl:\n            loss = learn.loss_func(learn.model(xb), yb)\n            loss.backward()\n            learn.opt.step()\n            learn.opt.zero_grad()\n\n        learn.model.eval()\n        with torch.no_grad():\n            tot_loss,tot_acc = 0.,0.\n            for xb,yb in learn.data.valid_dl:\n                pred = learn.model(xb)\n                tot_loss += learn.loss_func(pred, yb)\n                tot_acc  += accuracy(pred,yb)\n        nv = len(learn.data.valid_dl)\n        print(epoch, tot_loss/nv, tot_acc/nv)\n    return tot_loss/nv, tot_acc/nv",
"execution_count": 8,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "loss,acc = fit(1, learn)",
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": "0 tensor(0.1817) tensor(0.9450)\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## CallbackHandler"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "This was our training loop (without validation) from the previous notebook, with the inner loop contents factored out:\n\n```python\ndef one_batch(xb,yb):\n    pred = model(xb)\n    loss = loss_func(pred, yb)\n    loss.backward()\n    opt.step()\n    opt.zero_grad()\n\ndef fit():\n    for epoch in range(epochs):\n        for b in train_dl: one_batch(*b)\n```"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Add callbacks so we can remove complexity from the loop and make it flexible:"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def one_batch(xb, yb, cb):\n    if not cb.begin_batch(xb,yb): return\n    loss = cb.learn.loss_func(cb.learn.model(xb), yb)\n    if not cb.after_loss(loss): return\n    loss.backward()\n    if cb.after_backward(): cb.learn.opt.step()\n    if cb.after_step(): cb.learn.opt.zero_grad()\n\ndef all_batches(dl, cb):\n    for xb,yb in dl:\n        one_batch(xb, yb, cb)\n        if cb.do_stop(): return\n\ndef fit(epochs, learn, cb):\n    if not cb.begin_fit(learn): return\n    for epoch in range(epochs):\n        if not cb.begin_epoch(epoch): continue\n        all_batches(learn.data.train_dl, cb)\n\n        if cb.begin_validate():\n            with torch.no_grad(): all_batches(learn.data.valid_dl, cb)\n        if cb.do_stop() or not cb.after_epoch(): break\n    cb.after_fit()",
"execution_count": 10,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "class Callback():\n    def begin_fit(self, learn):\n        self.learn = learn\n        return True\n    def after_fit(self): return True\n    def begin_epoch(self, epoch):\n        self.epoch=epoch\n        return True\n    def begin_validate(self): return True\n    def after_epoch(self): return True\n    def begin_batch(self, xb, yb):\n        self.xb,self.yb = xb,yb\n        return True\n    def after_loss(self, loss):\n        self.loss = loss\n        return True\n    def after_backward(self): return True\n    def after_step(self): return True",
"execution_count": 11,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "class CallbackHandler():\n    def __init__(self,cbs=None):\n        self.cbs = cbs if cbs else []\n\n    def begin_fit(self, learn):\n        self.learn,self.in_train = learn,True\n        learn.stop = False\n        res = True\n        for cb in self.cbs: res = res and cb.begin_fit(learn)\n        return res\n\n    def after_fit(self):\n        res = not self.in_train\n        for cb in self.cbs: res = res and cb.after_fit()\n        return res\n\n    def begin_epoch(self, epoch):\n        self.learn.model.train()   # use the stored learner, not the global `learn`\n        self.in_train=True\n        res = True\n        for cb in self.cbs: res = res and cb.begin_epoch(epoch)\n        return res\n\n    def begin_validate(self):\n        self.learn.model.eval()\n        self.in_train=False\n        res = True\n        for cb in self.cbs: res = res and cb.begin_validate()\n        return res\n\n    def after_epoch(self):\n        res = True\n        for cb in self.cbs: res = res and cb.after_epoch()\n        return res\n\n    def begin_batch(self, xb, yb):\n        res = True\n        for cb in self.cbs: res = res and cb.begin_batch(xb, yb)\n        return res\n\n    def after_loss(self, loss):\n        res = self.in_train\n        for cb in self.cbs: res = res and cb.after_loss(loss)\n        return res\n\n    def after_backward(self):\n        res = True\n        for cb in self.cbs: res = res and cb.after_backward()\n        return res\n\n    def after_step(self):\n        res = True\n        for cb in self.cbs: res = res and cb.after_step()\n        return res\n\n    def do_stop(self):\n        # return the current flag, then reset it for the next loop\n        try:     return self.learn.stop\n        finally: self.learn.stop = False",
"execution_count": 12,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "class TestCallback(Callback):\n    def begin_fit(self,learn):\n        super().begin_fit(learn)\n        self.n_iters = 0\n        return True\n\n    def after_step(self):\n        self.n_iters += 1\n        print(self.n_iters)\n        if self.n_iters>=10: self.learn.stop = True   # stored learner, not the global `learn`\n        return True",
"execution_count": 13,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "fit(1, learn, cb=CallbackHandler([TestCallback()]))",
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"text": "1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "This is roughly how fastai does it now (except that the handler can also change and return `xb`, `yb`, and `loss`). But let's see if we can make things simpler and more flexible, so that a single class has access to everything and can change anything at any time. The fact that we're passing `cb` to so many functions is a strong hint they should all be in the same class!"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Runner"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "#export\nimport re\n\n_camel_re1 = re.compile('(.)([A-Z][a-z]+)')\n_camel_re2 = re.compile('([a-z0-9])([A-Z])')\ndef camel2snake(name):\n    s1 = re.sub(_camel_re1, r'\\1_\\2', name)\n    return re.sub(_camel_re2, r'\\1_\\2', s1).lower()\n\nclass Callback():\n    _order=0\n    def set_runner(self, run): self.run=run\n    def __getattr__(self, k): return getattr(self.run, k)\n    @property\n    def name(self):\n        name = re.sub(r'Callback$', '', self.__class__.__name__)\n        return camel2snake(name or 'callback')",
"execution_count": 15,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "This first callback is responsible for switching the model back and forth between training and validation mode, and for keeping count of the iterations (`n_iter`) and of the epochs elapsed as a float (`n_epochs`), whose fractional part is the share of the current epoch completed."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "#export\nclass TrainEvalCallback(Callback):\n    def begin_fit(self):\n        self.run.n_epochs=0.\n        self.run.n_iter=0\n\n    def after_batch(self):\n        if not self.in_train: return\n        self.run.n_epochs += 1./self.iters\n        self.run.n_iter   += 1\n\n    def begin_epoch(self):\n        self.run.n_epochs=self.epoch\n        self.model.train()\n        self.run.in_train=True\n\n    def begin_validate(self):\n        self.model.eval()\n        self.run.in_train=False",
"execution_count": 16,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "We'll also re-create our `TestCallback`:"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
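"source": "class TestCallback(Callback):\n    _order=1\n    def after_step(self):\n        if self.n_iter>=10:\n            # CHANGED: only request a stop; returning True here would also skip opt.zero_grad()\n            self.run.stop=True\n    def after_epoch(self): # ADDED\n        # True ends the epoch loop and also keeps callbacks with a higher _order from running after_epoch\n        return True",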
"execution_count": 17,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "cbname = 'TrainEvalCallback'\ncamel2snake(cbname)",
"execution_count": 18,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 18,
"data": {
"text/plain": "'train_eval_callback'"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "TrainEvalCallback().name",
"execution_count": 19,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 19,
"data": {
"text/plain": "'train_eval'"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "#export\nfrom typing import *\n\ndef listify(o):\n    if o is None: return []\n    if isinstance(o, list): return o\n    if isinstance(o, Iterable): return list(o)\n    return [o]",
"execution_count": 20,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "#export\nclass Runner():\n    def __init__(self, cbs=None, cb_funcs=None):\n        cbs = listify(cbs)\n        for cbf in listify(cb_funcs):\n            cb = cbf()\n            setattr(self, cb.name, cb)\n            cbs.append(cb)\n        self.stop,self.cbs = False,[TrainEvalCallback()]+cbs\n\n    @property\n    def opt(self):       return self.learn.opt\n    @property\n    def model(self):     return self.learn.model\n    @property\n    def loss_func(self): return self.learn.loss_func\n    @property\n    def data(self):      return self.learn.data\n\n    def one_batch(self, xb, yb):\n        self.xb,self.yb = xb,yb\n        if self('begin_batch'): return\n        self.pred = self.model(self.xb)\n        if self('after_pred'): return\n        self.loss = self.loss_func(self.pred, self.yb)\n        if self('after_loss') or not self.in_train: return\n        self.loss.backward()\n        if self('after_backward'): return\n        self.opt.step()\n        if self('after_step'): return\n        self.opt.zero_grad()\n\n    def all_batches(self, dl):\n        self.iters = len(dl)\n        for xb,yb in dl:\n            print(f'n_iter = {self.n_iter} / n_epochs = {self.n_epochs:.3f}') # CHANGED\n            if self.stop:\n                print('>>> BREAK') # CHANGED\n                break\n            self.one_batch(xb, yb)\n            self('after_batch')\n        self.stop=False\n\n    def fit(self, epochs, learn):\n        self.epochs,self.learn,self.loss = epochs,learn,tensor(0.)\n\n        try:\n            for cb in self.cbs: cb.set_runner(self)\n            if self('begin_fit'): return\n            for epoch in range(epochs):\n                self.epoch = epoch\n                if not self('begin_epoch'): self.all_batches(self.data.train_dl)\n\n                with torch.no_grad():\n                    if not self('begin_validate'): self.all_batches(self.data.valid_dl)\n                if self('after_epoch'): break\n\n        finally:\n            self('after_fit')\n            self.learn = None\n\n    def __call__(self, cb_name):\n        for cb in sorted(self.cbs, key=lambda x: x._order):\n            f = getattr(cb, cb_name, None)\n            if f and f(): return True\n        return False",
"execution_count": 21,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Third callback: how to compute metrics."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "#export\nclass AvgStats():\n    def __init__(self, metrics, in_train): self.metrics,self.in_train = listify(metrics),in_train\n\n    def reset(self):\n        self.tot_loss,self.count = 0.,0\n        self.tot_mets = [0.] * len(self.metrics)\n\n    @property\n    def all_stats(self): return [self.tot_loss.item()] + self.tot_mets\n    @property\n    def avg_stats(self): return [o/self.count for o in self.all_stats]\n\n    def __repr__(self):\n        if not self.count: return \"\"\n        return f\"{'train' if self.in_train else 'valid'}: {self.avg_stats}\"\n\n    def accumulate(self, run):\n        bn = run.xb.shape[0]\n        self.tot_loss += run.loss * bn\n        self.count += bn\n        for i,m in enumerate(self.metrics):\n            self.tot_mets[i] += m(run.pred, run.yb) * bn\n\nclass AvgStatsCallback(Callback):\n    _order=2 # CHANGED\n    def __init__(self, metrics):\n        self.train_stats,self.valid_stats = AvgStats(metrics,True),AvgStats(metrics,False)\n\n    def begin_epoch(self):\n        self.train_stats.reset()\n        self.valid_stats.reset()\n\n    def after_loss(self):\n        stats = self.train_stats if self.in_train else self.valid_stats\n        with torch.no_grad(): stats.accumulate(self.run)\n\n    def after_epoch(self):\n        print(self.train_stats)\n        print(self.valid_stats)",
"execution_count": 22,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn = Learner(*get_model(data), loss_func, data)",
"execution_count": 23,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# stats = AvgStatsCallback([accuracy])\nstats = [TestCallback(), AvgStatsCallback([accuracy])]\nrun = Runner(cbs=stats)",
"execution_count": 24,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "run.fit(2, learn)",
"execution_count": 25,
"outputs": [
{
"output_type": "stream",
"text": "n_iter = 0 / n_epochs = 0.000\nn_iter = 1 / n_epochs = 0.001\nn_iter = 2 / n_epochs = 0.003\nn_iter = 3 / n_epochs = 0.004\nn_iter = 4 / n_epochs = 0.005\nn_iter = 5 / n_epochs = 0.006\nn_iter = 6 / n_epochs = 0.008\nn_iter = 7 / n_epochs = 0.009\nn_iter = 8 / n_epochs = 0.010\nn_iter = 9 / n_epochs = 0.012\nn_iter = 10 / n_epochs = 0.013\n>>>>>> n_iter>=10 (10)\nn_iter = 11 / n_epochs = 0.014\n>>> BREAK\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\nn_iter = 11 / n_epochs = 0.014\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# `stats` is a list in this notebook; the AvgStatsCallback is its last element\nloss,acc = stats[-1].valid_stats.avg_stats\nassert acc>0.9\nloss,acc",
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "#export\nfrom functools import partial",
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "acc_cbf = partial(AvgStatsCallback,accuracy)",
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "run = Runner(cb_funcs=acc_cbf)",
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "run.fit(1, learn)",
"execution_count": null,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Using Jupyter means we can get tab-completion even for dynamic code like this! :)"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "run.avg_stats.valid_stats.avg_stats",
"execution_count": null,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Export"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "!python notebook2script.py 04_callbacks.ipynb",
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.7.2",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"gist": {
"id": "",
"data": {
"description": "git/fastai_docs/dev_course/dl2/tmp_TEST2_04_callbacks.ipynb",
"public": false
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
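
The `fit(1, learn)` refactoring in the notebook above works because `fit` reads the model, optimizer, loss function and data from one mutable object every time it uses them. A minimal pure-Python sketch of that idea, assuming a one-parameter toy model with a quadratic loss instead of PyTorch (the `ToyLearner` class, its `lr` and `w` attributes, and the `grad` helper are hypothetical, not the notebook's `Learner`): code outside `fit` changes `learn.lr`, and the loop picks the change up without any change to `fit`'s signature.

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class ToyLearner:
    """Hypothetical stand-in for Learner: one object holding everything fit() needs."""
    w: float = 0.0      # the single "model" parameter
    lr: float = 0.1     # learning rate, read from the object on every step
    data: List[float] = field(default_factory=lambda: [3.0] * 5)  # targets for one epoch

def grad(w, target):
    # derivative of the squared error (w - target)**2 with respect to w
    return 2 * (w - target)

def fit(epochs, learn):
    for _ in range(epochs):
        for target in learn.data:
            learn.w -= learn.lr * grad(learn.w, target)

learn = ToyLearner()
fit(1, learn)
w_first = learn.w       # moved from 0.0 towards 3.0

learn.lr = 0.0          # changed "elsewhere"; fit's signature stays the same
fit(1, learn)
w_second = learn.w      # unchanged, because fit read the new lr
```

Each step here computes `w = 0.8*w + 0.6`, so after one epoch `w` is `3 - 3*0.8**5`, about 2.017; the second call leaves it there because the loop saw `lr = 0.0`.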
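
In the notebook's `CallbackHandler`, every event combines the callbacks' answers with `res = res and cb.method(...)`, so one `False` vetoes the step, and `do_stop` uses `try`/`finally` to return the current `stop` flag while resetting it. A stripped-down sketch of both patterns, assuming a `range` stands in for the data loader and a `SimpleNamespace` for the `Learner` (`Handler`, `SkipOddCallback` and `StopAfterCallback` are hypothetical names):

```python
from types import SimpleNamespace

class Handler:
    """Sketch of CallbackHandler's aggregation and stop-flag pattern."""
    def __init__(self, cbs): self.cbs = cbs

    def begin_fit(self, learn):
        self.learn = learn
        learn.stop = False
        return True

    def begin_batch(self, xb):
        res = True
        # `and` short-circuits: after one False, later callbacks are not called
        for cb in self.cbs: res = res and cb.begin_batch(xb)
        return res

    def do_stop(self):
        try:     return self.learn.stop    # report the current flag...
        finally: self.learn.stop = False   # ...and always reset it afterwards

class SkipOddCallback:
    def begin_batch(self, xb): return xb % 2 == 0   # veto odd-numbered "batches"

class StopAfterCallback:
    def __init__(self, learn, limit): self.learn, self.limit, self.seen = learn, limit, []
    def begin_batch(self, xb):
        self.seen.append(xb)
        if len(self.seen) >= self.limit: self.learn.stop = True
        return True

def all_batches(dl, cb, processed):
    for xb in dl:
        if cb.begin_batch(xb): processed.append(xb)   # stands in for one_batch
        if cb.do_stop(): return

learn = SimpleNamespace()
stopper = StopAfterCallback(learn, limit=3)
cb = Handler([SkipOddCallback(), stopper])
cb.begin_fit(learn)
processed = []
all_batches(range(10), cb, processed)
```

`processed` ends up as `[0, 2, 4]`: odd items never reach `StopAfterCallback` because the `and` chain stops at the first `False`, and `learn.stop` is already `False` again when `all_batches` returns. The same short-circuit applies in the notebook's handler, so the order of the callbacks in the list matters.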
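
The notebook's second `Callback` class forwards any attribute it does not have to its `Runner` through `__getattr__`, which is why `TrainEvalCallback` can read `self.in_train`, `self.iters` or `self.model` directly. Writes do not go through `__getattr__`, though: `self.n_iter = 0` would create an attribute on the callback and shadow the runner's, which is why the callbacks write through `self.run`. A minimal sketch of that asymmetry (`ToyRunner` is a hypothetical stand-in for `Runner`):

```python
class Callback:
    def set_runner(self, run): self.run = run
    # __getattr__ only runs when normal lookup fails, so attributes set on
    # the callback itself always win over the runner's
    def __getattr__(self, k): return getattr(self.run, k)

class ToyRunner:
    def __init__(self): self.n_iter, self.in_train = 0, True

run = ToyRunner()
cb = Callback()
cb.set_runner(run)

read_before = cb.n_iter   # read: forwarded to run.n_iter (0)
cb.run.n_iter += 1        # write through the runner: run.n_iter becomes 1
cb.n_iter = 100           # pitfall: creates cb.n_iter and leaves run.n_iter alone
```

One more consequence of this design: reading any missing attribute on a callback before `set_runner` has been called recurses until Python raises `RecursionError`, because the missing `run` attribute is itself looked up through `__getattr__`.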
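
`TrainEvalCallback.after_batch` adds `1./self.iters` to `n_epochs` for every training batch, so `n_epochs` holds the epoch index plus the fraction of the current epoch completed. Assuming 782 training batches (MNIST's 50,000 training images at `bs=64`, with the last partial batch kept), this arithmetic reproduces the `n_epochs` column of the `run.fit(2, learn)` log in the notebook above; the plain variables here are a sketch of the runner's attributes, not the notebook's code:

```python
import math

n_train, bs = 50_000, 64            # assumption: MNIST training-set size and the notebook's bs
iters = math.ceil(n_train / bs)     # len(train_dl) when the last partial batch is kept

n_epochs = 0.0                      # what TrainEvalCallback sets in begin_fit/begin_epoch
trace = []
for n_iter in range(12):
    # Runner.all_batches prints before the batch runs
    trace.append(f'n_iter = {n_iter} / n_epochs = {n_epochs:.3f}')
    n_epochs += 1. / iters          # what TrainEvalCallback.after_batch adds per training batch
```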
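
`listify` lets `Runner` accept `None`, a single callback, or any iterable of callbacks. A standalone copy with a few probes, swapping `collections.abc.Iterable` in for the notebook's `from typing import *` (an assumption; both give the same `isinstance` results for these inputs):

```python
from collections.abc import Iterable

def listify(o):
    """Same logic as the notebook's listify."""
    if o is None: return []
    if isinstance(o, list): return o          # returned as-is, not copied
    if isinstance(o, Iterable): return list(o)
    return [o]

cbs = ['a_callback']
same = listify(cbs)
```

Two consequences are worth knowing: a string is iterable, so `listify('ab')` gives `['a', 'b']` rather than `['ab']`; and because a list comes back unchanged, `Runner.__init__`'s `cbs.append(cb)` appends to the caller's own list when both `cbs` and `cb_funcs` are passed.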
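
`Runner.__call__` visits callbacks sorted by `_order` and returns as soon as one of them returns a truthy value. That is why, with the `TestCallback` above (`_order=1`) returning `True` from `after_epoch`, `AvgStatsCallback` (`_order=2`) never gets to print its stats in the `run.fit(2, learn)` run. A minimal sketch with hypothetical stand-in classes:

```python
log = []

class TrainEvalCb:          # like TrainEvalCallback: defines no after_epoch at all
    _order = 0

class StopCb:               # like the debugging TestCallback
    _order = 1
    def after_epoch(self):
        log.append('stop')
        return True         # truthy: dispatch stops here

class AvgStatsCb:           # like AvgStatsCallback
    _order = 2
    def after_epoch(self):
        log.append('avg_stats')   # would print the train/valid stats

def dispatch(cbs, cb_name):
    """Same loop as Runner.__call__."""
    for cb in sorted(cbs, key=lambda x: x._order):
        f = getattr(cb, cb_name, None)   # missing methods are simply skipped
        if f and f(): return True
    return False

# the list order does not matter; _order does
stopped = dispatch([AvgStatsCb(), StopCb(), TrainEvalCb()], 'after_epoch')
```

`log` is `['stop']` after this call. Without `StopCb`, `AvgStatsCb.after_epoch` runs and, since it returns `None`, `dispatch` returns `False` and the epoch loop would continue.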
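
`Runner.all_batches` only checks `self.stop` at the top of the next iteration and clears it after the loop ends. So once `TestCallback` raises the flag at `n_iter` 10, one more `n_iter = 11` line is printed before `>>> BREAK`, and the validation pass still runs to completion with `n_iter` frozen at 11, since validation batches do not increment it. A plain-Python trace of that control flow, assuming 782 training and 79 validation batches (MNIST's 50,000/10,000 split with `bs=64` for training and, as we assume `get_dls` does, `2*bs` for validation); the dict-based `run` is a hypothetical stand-in for the `Runner`:

```python
def all_batches(n_batches, run, on_batch):
    """Mirrors Runner.all_batches: the flag is checked before each batch."""
    done = 0
    for _ in range(n_batches):
        if run['stop']: break
        on_batch(run)
        done += 1
    run['stop'] = False     # reset, so the next loop (validation) still runs
    return done

def train_batch(run):
    # TestCallback.after_step: ask to stop once 10 iterations have completed
    if run['n_iter'] >= 10: run['stop'] = True
    run['n_iter'] += 1      # TrainEvalCallback.after_batch (training only)

run = {'stop': False, 'n_iter': 0}
train_done = all_batches(782, run, train_batch)     # assumed len(train_dl)
valid_done = all_batches(79, run, lambda r: None)   # assumed len(valid_dl)
```

This gives 11 training batches and all 79 validation batches, matching the single `n_iter = 11` line before `>>> BREAK` and the run of `n_iter = 11` lines after it in the recorded output.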
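
`AvgStats.accumulate` multiplies each batch's mean loss and metrics by the batch size `bn` before summing, and `avg_stats` divides by the total `count`. A short final batch is therefore weighted by its size instead of counting as much as a full batch. A plain-Python sketch of the difference, using made-up per-batch numbers (hypothetical, for illustration only):

```python
# (batch size, mean loss over that batch): made-up numbers for illustration
batches = [(128, 0.20), (128, 0.30), (16, 1.00)]

tot_loss, count = 0., 0
for bn, mean_loss in batches:
    tot_loss += mean_loss * bn   # undo the per-batch mean, as AvgStats.accumulate does
    count += bn
weighted = tot_loss / count      # per-sample average, like AvgStats.avg_stats

naive = sum(l for _, l in batches) / len(batches)   # average of batch means
```

The small last batch's high loss pulls the average of batch means up to 0.5, while the size-weighted value is 80/272, about 0.294, which is the true mean over all 272 samples.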
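
`Runner(cb_funcs=acc_cbf)` receives a factory rather than an instance: `partial(AvgStatsCallback, accuracy)` pre-binds the metrics argument, `Runner.__init__` calls it with no arguments, and `setattr(self, cb.name, cb)` exposes the result as `run.avg_stats`, which is what makes tab-completion work. A self-contained sketch of just that wiring; `MiniRunner` and `dummy_accuracy` are hypothetical, and the `name` property copies the notebook's `camel2snake` logic:

```python
import re
from functools import partial

_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
def camel2snake(name):
    s1 = re.sub(_camel_re1, r'\1_\2', name)
    return re.sub(_camel_re2, r'\1_\2', s1).lower()

class Callback:
    @property
    def name(self):
        # strip a trailing "Callback", then convert CamelCase to snake_case
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')

class AvgStatsCallback(Callback):
    def __init__(self, metrics): self.metrics = metrics

class MiniRunner:
    def __init__(self, cb_funcs=()):
        self.cbs = []
        for cbf in cb_funcs:
            cb = cbf()                    # factory called with no arguments
            setattr(self, cb.name, cb)    # e.g. self.avg_stats
            self.cbs.append(cb)

def dummy_accuracy(pred, targ): return 1.0   # placeholder metric

acc_cbf = partial(AvgStatsCallback, dummy_accuracy)
run = MiniRunner(cb_funcs=[acc_cbf])
```

Passing factories lets the runner create fresh callback instances itself and name them consistently, so every runner gets its own `avg_stats` rather than sharing one stateful object.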