Skip to content

Instantly share code, notes, and snippets.

@anna-hope
Created May 30, 2019 18:36
Show Gist options
  • Save anna-hope/3c4d636f2b79e206b26acfe349f2657a to your computer and use it in GitHub Desktop.
Save anna-hope/3c4d636f2b79e206b26acfe349f2657a to your computer and use it in GitHub Desktop.
Torchtext dataset and iterator wrappers for Pandas DataFrames
from typing import Union, Dict
import pandas as pd
from torchtext.data import (Field, Example, Iterator, BucketIterator, Dataset)
from tqdm import tqdm
class DataFrameExampleSet:
    """Iterable view over a DataFrame that yields one torchtext ``Example`` per row.

    Examples are built lazily on every iteration, so the set can be iterated
    repeatedly (e.g. once per epoch) and re-ordered in between via
    :meth:`shuffle`.

    Args:
        df: Source data; each row becomes one ``Example``.
        fields: Mapping from column label to the ``Field`` that processes that
            column. Entries whose value is ``None`` are ignored.
    """

    def __init__(self, df: pd.DataFrame, fields: Dict[str, Field]):
        self._df = df
        self._fields = fields
        # ``Example.fromdict`` is handed ``key -> (attribute_name, field)``;
        # here the attribute name is the column label itself.
        self._fields_dict = {
            field_name: (field_name, field)
            for field_name, field in fields.items()
            if field is not None
        }

    def __iter__(self):
        """Yield an ``Example`` for each row, showing a tqdm progress bar."""
        # Plain tuples (name=None) are used instead of namedtuples because
        # namedtuple-based ``itertuples`` renames column labels that are not
        # valid identifiers, start with an underscore, or are repeated to
        # positional names, which would hide those columns from the fields.
        columns = list(self._df.columns)
        rows = self._df.itertuples(index=True, name=None)
        for index, *values in tqdm(rows, total=len(self)):
            record = dict(zip(columns, values))
            # Keep the row index reachable under "Index", as the namedtuple
            # form exposed it; a real column with that label takes precedence.
            record.setdefault("Index", index)
            yield Example.fromdict(record, fields=self._fields_dict)

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self._df)

    def shuffle(self, random_state=None):
        """Replace the underlying DataFrame with a row-shuffled copy.

        Args:
            random_state: Forwarded unchanged to ``DataFrame.sample``.
        """
        self._df = self._df.sample(frac=1.0, random_state=random_state)
class DataFrameDataset(Dataset):
    """torchtext ``Dataset`` whose examples come from a pandas DataFrame.

    Args:
        df: Source data; each row becomes one example.
        fields: Mapping from column label to the ``Field`` for that column.
        filter_pred: Optional predicate; only examples for which it returns
            a truthy value are kept. Defaults to ``None`` (keep everything).
    """

    def __init__(self, df: pd.DataFrame,
                 fields: Dict[str, Field], filter_pred=None):
        examples = DataFrameExampleSet(df, fields)
        if filter_pred is not None:
            # NOTE(review): the base Dataset is understood to apply
            # filter_pred with the built-in filter() and to rebuild a list
            # only when it was given a list; confirm against the installed
            # torchtext version. Handing it the lazy example set would then
            # leave a single-pass iterator with no length, so filtered data
            # is materialized up front. The unfiltered path stays lazy.
            examples = list(examples)
        super().__init__(examples, fields, filter_pred=filter_pred)
class DataFrameIterator(Iterator):
    """``Iterator`` variant aware of DataFrame-backed example sets."""

    def data(self):
        """Return the source of examples for the coming pass.

        When the dataset's examples are a ``DataFrameExampleSet``, the rows
        are reshuffled in place if ``self.shuffle`` is truthy and the dataset
        itself is returned; ``self.sort`` is not consulted on this path.
        Any other kind of examples is handled by the parent implementation.
        """
        example_set = self.dataset.examples
        # Anything that is not DataFrame-backed uses the stock behavior.
        if not isinstance(example_set, DataFrameExampleSet):
            return super().data()
        if self.shuffle:
            example_set.shuffle()
        return self.dataset
class DataFrameBucketIterator(BucketIterator):
    """``BucketIterator`` variant aware of DataFrame-backed example sets."""

    def data(self):
        """Return the source of examples for the coming pass.

        When the dataset's examples are a ``DataFrameExampleSet``, the rows
        are reshuffled in place if ``self.shuffle`` is truthy and the dataset
        itself is returned; ``self.sort`` is not consulted on this path.
        Any other kind of examples is handled by the parent implementation.
        """
        example_set = self.dataset.examples
        # Anything that is not DataFrame-backed uses the stock behavior.
        if not isinstance(example_set, DataFrameExampleSet):
            return super().data()
        if self.shuffle:
            example_set.shuffle()
        return self.dataset
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment