Skip to content

Instantly share code, notes, and snippets.

@ohmeow
Created March 8, 2019 01:01
Show Gist options
  • Save ohmeow/fe91aed6267cd779946ab9f10eccdab9 to your computer and use it in GitHub Desktop.
Save ohmeow/fe91aed6267cd779946ab9f10eccdab9 to your computer and use it in GitHub Desktop.
How to prepare datasets/dataloaders for seq2seq models using the fast.ai DataBlock API
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from fastai.text import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fastai version: 1.0.47.post1\n"
]
}
],
"source": [
"print(f'fastai version: {__version__}')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[PosixPath('/home/wgilliam/.fastai/data/giga-fren/giga-fren.release2.fixed.fr'),\n",
" PosixPath('/home/wgilliam/.fastai/data/giga-fren/giga-fren.release2.fixed.en')]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"PATH = Path('data/translate')\n",
"PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"DATA_PATH = untar_data(URLs.MT_ENG_FRA)\n",
"DATA_PATH.ls()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"folder = 'giga-fren.release2.fixed'\n",
"en_folder = DATA_PATH/f'{folder}.en'\n",
"fr_folder = DATA_PATH/f'{folder}.fr'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare data\n",
"\n",
"*You only need to run this section once to build the .csv file; on later runs, skip ahead to **Prepare data for training**.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Raw strings keep `\\?` as a regex escape instead of an invalid Python string escape.\n",
"# re_eq: a leading English question that starts with 'Wh'; re_fq: any leading question.\n",
"re_eq = re.compile(r'^(Wh[^?.!]+\\?)')\n",
"re_fq = re.compile(r'^([^?.!]+\\?)')\n",
"\n",
"# Open both corpus files in a `with` block so the handles are closed once the pairs are collected.\n",
"with open(en_folder, encoding='utf-8') as en_file, open(fr_folder, encoding='utf-8') as fr_file:\n",
"    # zip pairs the i-th English line with the i-th French line\n",
"    lines = ((re_eq.search(eq), re_fq.search(fq)) for eq, fq in zip(en_file, fr_file))\n",
"\n",
"    # The generator is lazy, so it must be consumed while the files are still open.\n",
"    # Keep only the pairs where both sides matched their question pattern.\n",
"    qs = [ {'english_text': e.group(), 'french_text': f.group()} for e, f in lines if e and f ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"qs[:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = pd.DataFrame(qs)\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df.to_csv(PATH/'english_french_translate.csv', index=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare data for training"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"52331\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>english_text</th>\n",
" <th>french_text</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>What is light ?</td>\n",
" <td>Qu’est-ce que la lumière?</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Who are we?</td>\n",
" <td>Où sommes-nous?</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Where did we come from?</td>\n",
" <td>D'où venons-nous?</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>What would we do without it?</td>\n",
" <td>Que ferions-nous sans elle ?</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>What is the absolute location (latitude and lo...</td>\n",
" <td>Quelle sont les coordonnées (latitude et longi...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" english_text \\\n",
"0 What is light ? \n",
"1 Who are we? \n",
"2 Where did we come from? \n",
"3 What would we do without it? \n",
"4 What is the absolute location (latitude and lo... \n",
"\n",
" french_text \n",
"0 Qu’est-ce que la lumière? \n",
"1 Où sommes-nous? \n",
"2 D'où venons-nous? \n",
"3 Que ferions-nous sans elle ? \n",
"4 Quelle sont les coordonnées (latitude et longi... "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"df = pd.read_csv(PATH/'english_french_translate.csv')\n",
"print(len(df))\n",
"display(df.head())"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(41865, 10466)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 80/20 split: sample the training rows, then keep every remaining row for validation\n",
"train_df = df.sample(frac=0.8, random_state=42)\n",
"valid_df = df.drop(index=train_df.index)\n",
"len(train_df), len(valid_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Custom DataBlock API code"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def seq2seq_pad_collate(samples:BatchSamples, pad_idx:int=1, pad_first:bool=False, \n",
"                        backwards:bool=False) -> Tuple[LongTensor, LongTensor]:\n",
"    \"\"\"Collate `samples` into one padded batch of inputs and one of targets.\n",
"\n",
"    After `to_data`, each sample is read as a pair: `s[0]` is the input sequence and `s[1]` the target.\n",
"    Inputs and targets are padded independently with `pad_idx`, each to the longest sequence on its\n",
"    own side of the batch. Padding is placed before the tokens when `pad_first` is true, after them\n",
"    otherwise. With `backwards=True` the token order of both tensors is reversed, with the padding\n",
"    still on the side requested by `pad_first`.\n",
"\n",
"    Returns `(x_res, y_res)`, two long tensors of shape `(len(samples), max_len)`.\n",
"    \"\"\"\n",
"    samples = to_data(samples)\n",
"    x_max_len = max(len(s[0]) for s in samples)\n",
"    y_max_len = max(len(s[1]) for s in samples)\n",
"\n",
"    # Start from tensors filled with `pad_idx`; the tokens are written over them below.\n",
"    x_res = torch.zeros(len(samples), x_max_len).long() + pad_idx\n",
"    y_res = torch.zeros(len(samples), y_max_len).long() + pad_idx\n",
"\n",
"    # When reversing, write the tokens on the opposite side so that the final flip\n",
"    # leaves the padding on the side the caller asked for.\n",
"    if backwards: pad_first = not pad_first\n",
"\n",
"    for i,s in enumerate(samples):\n",
"        x_len, y_len = len(s[0]), len(s[1])\n",
"        if pad_first:\n",
"            # Explicit start offsets instead of `-len`: `-0` would select the whole row\n",
"            # and fail on an empty sequence.\n",
"            x_res[i,x_max_len-x_len:] = LongTensor(s[0])\n",
"            y_res[i,y_max_len-y_len:] = LongTensor(s[1])\n",
"        else:\n",
"            x_res[i,:x_len] = LongTensor(s[0])\n",
"            y_res[i,:y_len] = LongTensor(s[1])\n",
"\n",
"    # Flip both tensors along the sequence dimension (this previously flipped an undefined\n",
"    # `res`, which raised NameError whenever backwards=True).\n",
"    if backwards: x_res, y_res = x_res.flip(1), y_res.flip(1)\n",
"\n",
"    return x_res, y_res"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"class Seq2SeqDataBunch(DataBunch):\n",
"    \"`DataBunch` whose batches are collated with `seq2seq_pad_collate` (inputs and targets both padded).\"\n",
"\n",
"    @classmethod\n",
"    def create(cls, train_ds, valid_ds, test_ds=None, \n",
"               path:PathOrStr='.', bs:int=32, val_bs:int=None, pad_idx=1, pad_first=False, \n",
"               device:torch.device=None, no_check:bool=False, backwards:bool=False, **dl_kwargs) -> DataBunch:\n",
"        \"\"\"Build a `DataBunch` for seq2seq training from the given datasets.\n",
"\n",
"        The training loader uses a `SortishSampler` with batch size `bs` and `drop_last=True`. Every\n",
"        other dataset gets a `SortSampler` keyed on input length and batch size `val_bs`\n",
"        (`ifnone(val_bs, bs)`). `pad_idx`, `pad_first` and `backwards` are forwarded to\n",
"        `seq2seq_pad_collate`; `**dl_kwargs` are passed on to `DataLoader()`.\n",
"        \"\"\"\n",
"        datasets = cls._init_ds(train_ds, valid_ds, test_ds)\n",
"        val_bs = ifnone(val_bs, bs)\n",
"\n",
"        collate_fn = partial(seq2seq_pad_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards)\n",
"\n",
"        # Precompute input lengths once, as the validation branch below does, instead of\n",
"        # building a full dataset item on every sort-key lookup.\n",
"        train_lengths = [len(t) for t in datasets[0].x.items]\n",
"        train_sampler = SortishSampler(datasets[0].x, key=train_lengths.__getitem__, bs=bs//2)\n",
"        train_dl = DataLoader(datasets[0], batch_size=bs, sampler=train_sampler, drop_last=True, **dl_kwargs)\n",
"\n",
"        dataloaders = [train_dl]\n",
"        for ds in datasets[1:]:\n",
"            lengths = [len(t) for t in ds.x.items]\n",
"            sampler = SortSampler(ds.x, key=lengths.__getitem__)\n",
"            dataloaders.append(DataLoader(ds, batch_size=val_bs, sampler=sampler, **dl_kwargs))\n",
"        return cls(*dataloaders, path=path, device=device, collate_fn=collate_fn, no_check=no_check)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"class Seq2SeqTextList(TextList):\n",
" _bunch = Seq2SeqDataBunch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Build our DataBunch"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"bs = 64\n",
"val_bs = 128"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"en_tok = Tokenizer(lang='en')\n",
"fr_tok = Tokenizer(lang='fr')"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"en_procs = [TokenizeProcessor(tokenizer=en_tok), NumericalizeProcessor()]\n",
"fr_procs = [TokenizeProcessor(tokenizer=fr_tok), NumericalizeProcessor()]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"en_train_il = Seq2SeqTextList.from_df(train_df, path=PATH, cols=['english_text'], processor=en_procs)\n",
"fr_train_il = Seq2SeqTextList.from_df(train_df, path=PATH, cols=['french_text'], processor=fr_procs)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"en_valid_il = Seq2SeqTextList.from_df(valid_df, path=PATH, cols=['english_text'])\n",
"fr_valid_il = Seq2SeqTextList.from_df(valid_df, path=PATH, cols=['french_text'])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"trn_ll = LabelList(en_train_il, fr_train_il)\n",
"val_ll = LabelList(en_valid_il, fr_valid_il)\n",
"\n",
"lls = LabelLists(PATH, train=trn_ll, valid=val_ll).process()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(8174, 8174)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(lls.train.vocab.itos), len(lls.valid.vocab.itos)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating a DataBunch via the API"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"data = lls.databunch(bs=bs, val_bs=val_bs)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 200]), torch.Size([64, 198]))"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"b = next(iter(data.train_dl))\n",
"b[0].shape, b[1].shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating DataBunch via DataLoaders"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"train_sampler = SortishSampler(lls.train.x, key=lambda t: len(lls.train[t][0].data), bs=bs//2)\n",
"valid_sampler = SortSampler(lls.valid.x, key=lambda t: len(lls.valid[t][0].data))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"train_dl = DataLoader(lls.train, batch_size=bs, sampler=train_sampler, drop_last=True)\n",
"valid_dl = DataLoader(lls.valid, batch_size=val_bs, sampler=valid_sampler)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"data = DataBunch(train_dl=train_dl, valid_dl=valid_dl, collate_fn=seq2seq_pad_collate)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 200]), torch.Size([64, 198]))"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"b = next(iter(data.train_dl))\n",
"b[0].shape, b[1].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@akhavr
Copy link

akhavr commented Mar 8, 2019

if backwards: res = res.flip(1)

res here is used only once. Are you sure about this line?

@ohmeow
Copy link
Author

ohmeow commented Mar 8, 2019

That is just copied over from the framework code ... not sure how necessary it is here.

@bfarzin
Copy link

bfarzin commented May 23, 2019

This was very helpful. Thanks for posting it!!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment