@sai-prasanna
Last active February 26, 2020 09:25
Multiprocess seq2seq reader for AllenNLP, built on PyTorch's DataLoader and Dataset.
import csv
from typing import Dict, Optional
import logging
import torch
import random
from collections import Counter
import numpy as np
from overrides import overrides
from torch.utils.data import Dataset, IterableDataset, DataLoader, DistributedSampler
from allennlp.common.checks import ConfigurationError
from allennlp.common.file_utils import cached_path
from allennlp.common.util import START_SYMBOL, END_SYMBOL
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import TextField
from allennlp.data.instance import Instance
from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
logger = logging.getLogger(__name__)
@DatasetReader.register("seq2seq")
class Seq2SeqDatasetReader(DatasetReader):
"""
Read a tsv file containing paired sequences, and create a dataset suitable for a
``ComposedSeq2Seq`` model, or any model with a matching API.
Expected format for each input line: <source_sequence_string>\t<target_sequence_string>
The output of ``read`` is a list of ``Instance`` s with the fields:
source_tokens: ``TextField`` and
target_tokens: ``TextField``
`START_SYMBOL` and `END_SYMBOL` tokens are added to the source and target sequences.
Parameters
----------
source_tokenizer : ``Tokenizer``, optional
Tokenizer to use to split the input sequences into words or other kinds of tokens. Defaults
to ``WordTokenizer()``.
target_tokenizer : ``Tokenizer``, optional
Tokenizer to use to split the output sequences (during training) into words or other kinds
of tokens. Defaults to ``source_tokenizer``.
source_token_indexers : ``Dict[str, TokenIndexer]``, optional
Indexers used to define input (source side) token representations. Defaults to
``{"tokens": SingleIdTokenIndexer()}``.
target_token_indexers : ``Dict[str, TokenIndexer]``, optional
Indexers used to define output (target side) token representations. Defaults to
``source_token_indexers``.
source_add_start_token : bool, (optional, default=True)
Whether or not to add `START_SYMBOL` to the beginning of the source sequence.
delimiter : str, (optional, default="\t")
Set delimiter for tsv/csv file.
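
    Example
    -------
    A minimal usage sketch; the path below is hypothetical and should point at a tab-separated
    file of source/target pairs::

        reader = Seq2SeqDatasetReader(source_max_tokens=100, lazy=True)
        for instance in reader.read("data/train.tsv"):
            print(instance.fields["source_tokens"])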
"""

    def __init__(
        self,
        source_tokenizer: Tokenizer = None,
        target_tokenizer: Tokenizer = None,
        source_token_indexers: Dict[str, TokenIndexer] = None,
        target_token_indexers: Dict[str, TokenIndexer] = None,
        source_add_start_token: bool = True,
        delimiter: str = "\t",
        source_max_tokens: Optional[int] = None,
        target_max_tokens: Optional[int] = None,
        lazy: bool = False,
    ) -> None:
        super().__init__(lazy)
        self._source_tokenizer = source_tokenizer or WordTokenizer()
        self._target_tokenizer = target_tokenizer or self._source_tokenizer
        self._source_token_indexers = source_token_indexers or {"tokens": SingleIdTokenIndexer()}
        self._target_token_indexers = target_token_indexers or self._source_token_indexers
        self._source_add_start_token = source_add_start_token
        self._delimiter = delimiter
        self._source_max_tokens = source_max_tokens
        self._target_max_tokens = target_max_tokens
        self._source_max_exceeded = 0
        self._target_max_exceeded = 0
        self._epoch_counter = Counter()
        self._initial_seed = 1337

    def _raw_dataset(self, file_path):
        paired_sequences = []
        with open(cached_path(file_path), "r") as data_file:
            logger.info("Reading instances from lines in file at: %s", file_path)
            for row in csv.reader(data_file, delimiter=self._delimiter):
                # Skip lines that do not contain exactly one source and one target sequence.
                if len(row) != 2:
                    continue
                paired_sequences.append(row)
        # Shuffle deterministically; the per-file epoch counter changes the seed every epoch.
        np.random.RandomState(self._initial_seed + self._epoch_counter[file_path]).shuffle(
            paired_sequences
        )
        return paired_sequences

    def _to_instance(self, raw_data_item):
        # Called inside DataLoader worker processes to turn a (source, target) row into an Instance.
        return self.text_to_instance(*raw_data_item)

    @overrides
    def _read(self, file_path):
        raw_dataset = self._raw_dataset(file_path)
        dataset = _DatasetWrapper(raw_dataset, instancizer=self)
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            # In distributed training, each process reads only its own shard of the dataset.
            sampler = DistributedSampler(dataset)
        else:
            # Outside distributed training, fall back to a single replica at rank 0.
            sampler = DistributedSampler(dataset, num_replicas=1, rank=0)
        # batch_size here only controls how many instances each worker sends back at a time;
        # the identity collate_fn keeps them as plain lists of Instances (no tensorization).
        loader = DataLoader(dataset, batch_size=100, num_workers=2, sampler=sampler, collate_fn=identity)
        for instances in loader:
            for instance in instances:
                # Instances come back from worker processes as pickled copies, so point their
                # fields at this reader's token indexers again.
                instance["source_tokens"]._token_indexers = self._source_token_indexers
                instance["target_tokens"]._token_indexers = self._target_token_indexers
                yield instance
        # Advance the per-file epoch counter so the next pass over this file reshuffles.
        self._epoch_counter[file_path] += 1

    @overrides
    def text_to_instance(
        self, source_string: str, target_string: str = None
    ) -> Instance:  # type: ignore
        tokenized_source = self._source_tokenizer.tokenize(source_string)
        if self._source_max_tokens and len(tokenized_source) > self._source_max_tokens:
            self._source_max_exceeded += 1
            tokenized_source = tokenized_source[: self._source_max_tokens]
        if self._source_add_start_token:
            tokenized_source.insert(0, Token(START_SYMBOL))
        tokenized_source.append(Token(END_SYMBOL))
        source_field = TextField(tokenized_source, self._source_token_indexers)
        if target_string is not None:
            tokenized_target = self._target_tokenizer.tokenize(target_string)
            if self._target_max_tokens and len(tokenized_target) > self._target_max_tokens:
                self._target_max_exceeded += 1
                tokenized_target = tokenized_target[: self._target_max_tokens]
            tokenized_target.insert(0, Token(START_SYMBOL))
            tokenized_target.append(Token(END_SYMBOL))
            target_field = TextField(tokenized_target, self._target_token_indexers)
            return Instance({"source_tokens": source_field, "target_tokens": target_field})
        else:
            return Instance({"source_tokens": source_field})


class _DatasetWrapper(Dataset):
    """
    Wraps the raw (source, target) pairs in a torch ``Dataset`` so a ``DataLoader`` can index
    them, letting its worker processes call back into the reader to build ``Instance`` s.
    """

    def __init__(self, raw_dataset, instancizer):
        self._raw_dataset = raw_dataset
        self.instancizer = instancizer

    def __getitem__(self, index):
        return self.instancizer._to_instance(self._raw_dataset[index])

    def __len__(self):
        return len(self._raw_dataset)


def identity(x):
    # Collate function that returns the list of Instances unchanged instead of stacking tensors.
    return x
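

# A quick, self-contained smoke test sketching how the reader above is meant to be used:
# write a tiny tab-separated file, then read it through the multiprocess DataLoader path.
# The file contents and path are made up for illustration; the __main__ guard matters because
# the DataLoader may spawn worker processes on some platforms.
if __name__ == "__main__":
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        tsv_path = os.path.join(tmp_dir, "toy_pairs.tsv")
        with open(tsv_path, "w") as data_file:
            data_file.write("a source sentence\ta target sentence\n")
            data_file.write("another source line\tanother target line\n")

        reader = Seq2SeqDatasetReader(lazy=True)
        for instance in reader.read(tsv_path):
            print(instance)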