@hiromis
Created March 1, 2018 04:04
Off by one error
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline\n",
"\n",
"import torchtext\n",
"\n",
"from torchtext import vocab, data\n",
"\n",
"from fastai.nlp import *\n",
"from fastai.lm_rnn import *\n",
"from fastai.learner import *\n",
"from fastai.column_data import *\n",
"from fastai.hiromi import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"PATH = 'data/toxic/'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"train_df = pd.read_csv(f'{PATH}train.csv')\n",
"test_df = pd.read_csv(f'{PATH}test.csv')\n",
"sample_submit_df = pd.read_csv(f'{PATH}sample_submission.csv')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# =========== SAMPLE ===================\n",
"train_df = train_df[:1000]\n",
"val_df = train_df[:200]\n",
"test_df = test_df[:200]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"bs=64\n",
"bptt=70"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"TEXT = data.Field(tokenize=spacy_tok)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"FILES = dict(train_df=train_df, val_df=test_df, test_df=test_df)\n",
"md = LanguageModelData.from_dataframes(PATH, TEXT, 'comment_text', **FILES, bs=bs, bptt=bptt, min_freq=10)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"LABEL = data.Field(sequential=False)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"class MyDataset(torchtext.data.Dataset):\n",
" def __init__(self, df, text_field, label_field, is_test=False, **kwargs):\n",
" fields = [('text', text_field), ('label', label_field)]\n",
" examples = []\n",
" for i, row in df.iterrows():\n",
" label = 'pos'\n",
" if not is_test and row['toxic']==0:\n",
" label = 'neg' \n",
" text = row['comment_text']\n",
" examples.append(torchtext.data.Example.fromlist([text, label], fields))\n",
"\n",
" super().__init__(examples, fields, **kwargs)\n",
"\n",
" @staticmethod\n",
" def sort_key(ex): return len(ex.text)\n",
" \n",
" @classmethod\n",
" def splits(cls, text_field, label_field, train_df, val_df=None, test_df=None, **kwargs):\n",
" train_data, val_data, test_data = (None, None, None)\n",
"\n",
" if train_df is not None:\n",
" train_data = cls(train_df.copy(), text_field, label_field, **kwargs)\n",
" if val_df is not None:\n",
" val_data = cls(val_df.copy(), text_field, label_field, **kwargs)\n",
" if test_df is not None:\n",
" test_data = cls(test_df.copy(), text_field, label_field, True, **kwargs)\n",
"\n",
" return tuple(d for d in (train_data, val_data, test_data) if d is not None)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Before the change:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"splits = MyDataset.splits(TEXT, LABEL, train_df=train_df, val_df=val_df, test_df=test_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have 200 test data"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"200"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(test_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Really, it's 200!"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"200"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(splits[1].examples)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"bs=10"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"md = TextData.from_splits(PATH, splits, bs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But we only get 19 batches (190 rows)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"19"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(md.test_dl)"
]
},
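{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sanity check added here as a sketch (not part of the original run): with `bs=10` and 200 test rows, the test DataLoader should contain `ceil(200 / 10) = 20` batches."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sanity check (not in the original run): expected number of test batches.\n",
"# With 200 test rows and bs=10 this is 20, one more than the 19 reported above.\n",
"import math\n",
"math.ceil(len(test_df) / bs)"
]
},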
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## After the change:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"splits = MyDataset.splits(TEXT, LABEL, train_df=train_df, val_df=val_df, test_df=test_df)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"md = TextData.from_splits(PATH, splits, bs)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"20"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(md.test_dl)"
]
},
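{
"cell_type": "markdown",
"metadata": {},
"source": [
"A verification sketch (not part of the original run): after the change, the number of test batches should match the expected `ceil(len(test_df) / bs)`, so every test row is covered."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added verification (not in the original run): the fixed test DataLoader should\n",
"# now cover all 200 test rows, i.e. ceil(200 / 10) == 20 batches of bs=10.\n",
"import math\n",
"assert len(md.test_dl) == math.ceil(len(test_df) / bs)"
]
},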
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"_draft": {
"nbviewer_url": "https://gist.github.com/fc03e2317dc739c34c846ef410fafb72"
},
"gist": {
"data": {
"description": "Bug description",
"public": true
},
"id": "fc03e2317dc739c34c846ef410fafb72"
},
"kernelspec": {
"display_name": "fastai",
"language": "python",
"name": "fastai"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}