-
-
Save ohmeow/fe91aed6267cd779946ab9f10eccdab9 to your computer and use it in GitHub Desktop.
How to prepare datasets/dataloaders for seq2seq models using the fast.ai DataBlock API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from fastai.text import *" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"fastai version: 1.0.47.post1\n" | |
] | |
} | |
], | |
"source": [ | |
"print(f'fastai version: {__version__}')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[PosixPath('/home/wgilliam/.fastai/data/giga-fren/giga-fren.release2.fixed.fr'),\n", | |
" PosixPath('/home/wgilliam/.fastai/data/giga-fren/giga-fren.release2.fixed.en')]" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"PATH = Path('data/translate')\n", | |
"PATH.mkdir(parents=True, exist_ok=True)\n", | |
"\n", | |
"DATA_PATH = untar_data(URLs.MT_ENG_FRA)\n", | |
"DATA_PATH.ls()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"folder = 'giga-fren.release2.fixed'\n", | |
"en_folder = DATA_PATH/f'{folder}.en'\n", | |
"fr_folder = DATA_PATH/f'{folder}.fr'" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Prepare data\n", | |
"\n", | |
"*You only need to run this section once, to build the .csv file.*" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"re_eq = re.compile('^(Wh[^?.!]+\\?)')\n", | |
"re_fq = re.compile('^([^?.!]+\\?)')\n", | |
"\n", | |
"lines = ((re_eq.search(eq), re_fq.search(fq)) \n", | |
" for eq, fq in zip(open(en_folder, encoding='utf-8'), open(fr_folder, encoding='utf-8')))\n", | |
"\n", | |
"qs = [ {'english_text': e.group(), 'french_text': f.group()} for e, f in lines if e and f ]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"qs[:5]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"df = pd.DataFrame(qs)\n", | |
"df.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"df.to_csv(PATH/'english_french_translate.csv', index=False)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Prepare data for training" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"np.random.seed(42)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"52331\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>english_text</th>\n", | |
" <th>french_text</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>What is light ?</td>\n", | |
" <td>Qu’est-ce que la lumière?</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>Who are we?</td>\n", | |
" <td>Où sommes-nous?</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>Where did we come from?</td>\n", | |
" <td>D'où venons-nous?</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>What would we do without it?</td>\n", | |
" <td>Que ferions-nous sans elle ?</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>What is the absolute location (latitude and lo...</td>\n", | |
" <td>Quelle sont les coordonnées (latitude et longi...</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" english_text \\\n", | |
"0 What is light ? \n", | |
"1 Who are we? \n", | |
"2 Where did we come from? \n", | |
"3 What would we do without it? \n", | |
"4 What is the absolute location (latitude and lo... \n", | |
"\n", | |
" french_text \n", | |
"0 Qu’est-ce que la lumière? \n", | |
"1 Où sommes-nous? \n", | |
"2 D'où venons-nous? \n", | |
"3 Que ferions-nous sans elle ? \n", | |
"4 Quelle sont les coordonnées (latitude et longi... " | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"df = pd.read_csv(PATH/'english_french_translate.csv')\n", | |
"print(len(df))\n", | |
"display(df.head())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(41865, 10466)" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_df = df.sample(frac=0.8, random_state=42)\n", | |
"valid_df = df.iloc[~df.index.isin(train_df.index)]\n", | |
"len(train_df), len(valid_df)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Custom DataBlock API code" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def seq2seq_pad_collate(samples:BatchSamples, pad_idx:int=1, pad_first:bool=False, \n", | |
" backwards:bool=False) -> Tuple[LongTensor, LongTensor]:\n", | |
" \n", | |
" \"Function that collect samples and adds padding. Flips token order if needed\"\n", | |
" \n", | |
" samples = to_data(samples)\n", | |
" x_max_len = max([len(s[0]) for s in samples])\n", | |
" y_max_len = max([len(s[1]) for s in samples])\n", | |
" \n", | |
" x_res = torch.zeros(len(samples), x_max_len).long() + pad_idx\n", | |
" y_res = torch.zeros(len(samples), y_max_len).long() + pad_idx\n", | |
" \n", | |
" if backwards: pad_first = not pad_first\n", | |
" \n", | |
" for i,s in enumerate(samples):\n", | |
" if pad_first: \n", | |
" x_res[i,-len(s[0]):] = LongTensor(s[0])\n", | |
" y_res[i,-len(s[1]):] = LongTensor(s[1])\n", | |
" else: \n", | |
" x_res[i,:len(s[0]):] = LongTensor(s[0])\n", | |
" y_res[i,:len(s[1]):] = LongTensor(s[1])\n", | |
" \n", | |
" if backwards: res = res.flip(1)\n", | |
" \n", | |
" return x_res, y_res" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Seq2SeqDataBunch(DataBunch):\n", | |
"\n", | |
" @classmethod\n", | |
" def create(cls, train_ds, valid_ds, test_ds=None, \n", | |
" path:PathOrStr='.', bs:int=32, val_bs:int=None, pad_idx=1, pad_first=False, \n", | |
" device:torch.device=None, no_check:bool=False, backwards:bool=False, **dl_kwargs) -> DataBunch:\n", | |
" \n", | |
" \"\"\"Function that transform the `datasets` in a `DataBunch` for classification. \n", | |
" Passes `**dl_kwargs` on to `DataLoader()`\"\"\"\n", | |
" \n", | |
" datasets = cls._init_ds(train_ds, valid_ds, test_ds)\n", | |
" val_bs = ifnone(val_bs, bs)\n", | |
" \n", | |
" collate_fn = partial(seq2seq_pad_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards)\n", | |
" \n", | |
" train_sampler = SortishSampler(datasets[0].x, key=lambda t: len(datasets[0][t][0].data), bs=bs//2)\n", | |
" train_dl = DataLoader(datasets[0], batch_size=bs, sampler=train_sampler, drop_last=True, **dl_kwargs)\n", | |
" \n", | |
" dataloaders = [train_dl]\n", | |
" for ds in datasets[1:]:\n", | |
" lengths = [len(t) for t in ds.x.items]\n", | |
" sampler = SortSampler(ds.x, key=lengths.__getitem__)\n", | |
" dataloaders.append(DataLoader(ds, batch_size=val_bs, sampler=sampler, **dl_kwargs))\n", | |
" return cls(*dataloaders, path=path, device=device, collate_fn=collate_fn, no_check=no_check)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Seq2SeqTextList(TextList):\n", | |
"    \"`TextList` whose `_bunch` is `Seq2SeqDataBunch`, so `LabelLists.databunch()` builds seq2seq batches\"\n", | |
"    _bunch = Seq2SeqDataBunch" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Build our DataBunch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"bs = 64\n", | |
"val_bs = 128" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"en_tok = Tokenizer(lang='en')\n", | |
"fr_tok = Tokenizer(lang='fr')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"en_procs = [TokenizeProcessor(tokenizer=en_tok), NumericalizeProcessor()]\n", | |
"fr_procs = [TokenizeProcessor(tokenizer=fr_tok), NumericalizeProcessor()]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"en_train_il = Seq2SeqTextList.from_df(train_df, path=PATH, cols=['english_text'], processor=en_procs)\n", | |
"fr_train_il = Seq2SeqTextList.from_df(train_df, path=PATH, cols=['french_text'], processor=fr_procs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"en_valid_il = Seq2SeqTextList.from_df(valid_df, path=PATH, cols=['english_text'])\n", | |
"fr_valid_il = Seq2SeqTextList.from_df(valid_df, path=PATH, cols=['french_text'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"trn_ll = LabelList(en_train_il, fr_train_il)\n", | |
"val_ll = LabelList(en_valid_il, fr_valid_il)\n", | |
"\n", | |
"lls = LabelLists(PATH, train=trn_ll, valid=val_ll).process()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(8174, 8174)" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"len(lls.train.vocab.itos), len(lls.valid.vocab.itos)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Creating a DataBunch via the API" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"data = lls.databunch(bs=bs, val_bs=val_bs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(torch.Size([64, 200]), torch.Size([64, 198]))" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"b = next(iter(data.train_dl))\n", | |
"b[0].shape, b[1].shape" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Creating a DataBunch via DataLoaders" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_sampler = SortishSampler(lls.train.x, key=lambda t: len(lls.train[t][0].data), bs=bs//2)\n", | |
"valid_sampler = SortSampler(lls.valid.x, key=lambda t: len(lls.valid[t][0].data))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_dl = DataLoader(lls.train, batch_size=bs, sampler=train_sampler, drop_last=True)\n", | |
"valid_dl = DataLoader(lls.valid, batch_size=val_bs, sampler=valid_sampler)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"data = DataBunch(train_dl=train_dl, valid_dl=valid_dl, collate_fn=seq2seq_pad_collate)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(torch.Size([64, 200]), torch.Size([64, 198]))" | |
] | |
}, | |
"execution_count": 23, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"b = next(iter(data.train_dl))\n", | |
"b[0].shape, b[1].shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
That is just copied over from the framework code ... not sure how necessary it is here.
This was very helpful. Thanks for posting it!!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
`res` here is used only once. Are you sure about this line?