{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/tako/dev/env37/lib/python3.7/site-packages/pandas/compat/__init__.py:85: UserWarning: Could not import the lzma module. Your installed Python is incomplete. Attempting to use lzma compression will result in a RuntimeError.\n",
" warnings.warn(msg)\n"
]
}
],
"source": [
"from fastai2.torch_basics import *\n",
"from fastai2.data.all import *\n",
"from fastai2.text.core import *\n",
"from fastai2.text.data import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def _maybe_first(o): return o[0] if isinstance(o, tuple) else o\n",
"def _get_lengths(ds):\n",
" tok = _get_tokenizer(ds)\n",
" if tok is None: return\n",
" return tok.get_lengths(ds.items)\n",
"def _get_tokenizer(ds):\n",
" tok = getattr(ds, 'tokenizer', None)\n",
" if isinstance(tok, Tokenizer): return tok\n",
" if isinstance(tok, (list,L)):\n",
" for t in tok:\n",
" if isinstance(t, Tokenizer): return t\n",
" \n",
"@delegates()\n",
"class LMDataLoader(TfmdDL):\n",
" def __init__(self, dataset, lens=None, cache=2, bs=64, seq_len=72, num_workers=0, **kwargs):\n",
" self.items = ReindexCollection(dataset, cache=cache, tfm=_maybe_first)\n",
" self.seq_len = seq_len\n",
" if lens is None: lens = _get_lengths(dataset)\n",
" if lens is None: lens = [len(o) for o in self.items]\n",
" self.lens = ReindexCollection(lens, idxs=self.items.idxs)\n",
" # The \"-1\" is to allow for final label, we throw away the end that's less than bs\n",
" corpus = round_multiple(sum(lens)-1, bs, round_down=True)\n",
" self.bl = corpus//bs #bl stands for batch length\n",
" self.n_batches = self.bl//(seq_len) + int(self.bl%seq_len!=0)\n",
" self.last_len = self.bl - (self.n_batches-1)*seq_len\n",
" self.make_chunks()\n",
" super().__init__(dataset=dataset, bs=bs, num_workers=num_workers, **kwargs)\n",
" self.n = self.n_batches*bs\n",
"\n",
" @delegates(DataLoader.new)\n",
" def new(self, dataset=None, cls=None, **kwargs):\n",
" res = super().new(dataset, cls, **kwargs) \n",
" res.seq_len = self.seq_len\n",
" res.items = self.items\n",
" res.lens = self.lens \n",
" res.bl = self.bl\n",
" res.n_batches = self.n_batches\n",
" res.last_len = self.last_len\n",
" res.make_chunks() \n",
" res.n = self.n_batches*bs \n",
" return res\n",
" \n",
" def make_chunks(self): self.chunks = Chunks(self.items, self.lens)\n",
" def shuffle_fn(self,idxs):\n",
" self.items.shuffle()\n",
" self.make_chunks()\n",
" return idxs\n",
"\n",
" def create_item(self, seq):\n",
" if seq>=self.n: raise IndexError\n",
" sl = self.last_len if seq//self.bs==self.n_batches-1 else self.seq_len\n",
" st = (seq%self.bs)*self.bl + (seq//self.bs)*self.seq_len\n",
" txt = self.chunks[st : st+sl+1]\n",
" return LMTensorText(txt[:-1]),txt[1:]"
]
},
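{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added illustration, not part of the original gist: a minimal sketch of the batch\n",
"# geometry that __init__ computes for the toy data used in the next cell (five\n",
"# tensors of lengths 5, 6, 8, 2 and 2, i.e. 23 tokens), assuming bs=4 and seq_len=3.\n",
"# Only round_multiple from the fastai2 imports above is needed.\n",
"lens = [5, 6, 8, 2, 2]\n",
"bs, seq_len = 4, 3\n",
"corpus = round_multiple(sum(lens)-1, bs, round_down=True)  # 23-1=22, rounded down to 20\n",
"bl = corpus//bs                                            # 5 input tokens per stream\n",
"n_batches = bl//seq_len + int(bl%seq_len!=0)               # 2 batches per epoch\n",
"last_len = bl - (n_batches-1)*seq_len                      # final batch is only 2 tokens long\n",
"corpus, bl, n_batches, last_len                            # expected: (20, 5, 2, 2)"
]
},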
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"bs,sl = 4,3\n",
"ints = L([0,1,2,3,4],[5,6,7,8,9,10],[11,12,13,14,15,16,17,18],[19,20],[21,22]).map(tensor)\n",
"dl = LMDataLoader(ints, bs=bs, seq_len=sl)\n",
"list(dl)\n",
"test_eq(list(dl),\n",
" [[tensor([[0, 1, 2], [5, 6, 7], [10, 11, 12], [15, 16, 17]]),\n",
" tensor([[1, 2, 3], [6, 7, 8], [11, 12, 13], [16, 17, 18]])],\n",
" [tensor([[3, 4], [8, 9], [13, 14], [18, 19]]),\n",
" tensor([[4, 5], [9, 10], [14, 15], [19, 20]])]])\n",
"\n",
"dl_new = dl.new()\n",
"test_eq(dl.one_batch(),dl_new.one_batch())"
]
},
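{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sketch, not part of the original gist: spell out how create_item maps a flat\n",
"# sample index onto the concatenated corpus. Stream seq%bs reads from offset (seq%bs)*bl,\n",
"# batch seq//bs starts seq_len tokens further into that stream, and only the final batch\n",
"# uses the shorter last_len window. Uses only dl and test_eq from the cells above.\n",
"for seq in range(dl.n):\n",
"    batch_idx = seq//dl.bs\n",
"    sl = dl.last_len if batch_idx == dl.n_batches-1 else dl.seq_len\n",
"    x, y = dl.create_item(seq)\n",
"    test_eq(len(x), sl)      # window length matches this batch's sequence length\n",
"    test_eq(x[1:], y[:-1])   # targets are the inputs shifted by one token"
]
},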
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "env37",
"language": "python",
"name": "env37"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}