Created July 1, 2020 18:25 by epwalsh
Dataset Reader API
Proposal for new DatasetReader API.
For this to work, all `Instance`s would have to be efficiently serializable, since they are passed between processes.
So `TextField`s, for example, shouldn't contain `TokenIndexer`s.
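A minimal sketch of what this constraint implies, assuming a hypothetical `LightTextField` that keeps only its token strings and receives indexers when it is indexed rather than when it is constructed. `LightTextField`, its `index` method and `token_length_indexer` are illustrative names, not AllenNLP's actual `TextField` API. The point is only that such a field round-trips through `pickle` cheaply, which is what moving it between processes requires.

```python
import pickle
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

# Hypothetical indexer type: turns a list of tokens into a list of ids.
TokenIndexer = Callable[[List[str]], List[int]]


@dataclass
class LightTextField:
    """Hypothetical field holding only plain data, so it pickles cheaply."""

    tokens: List[str]
    # Filled in by `index`, after the field has crossed a process boundary.
    indexed: Optional[Dict[str, List[int]]] = None

    def index(self, token_indexers: Dict[str, TokenIndexer]) -> None:
        # Indexers are passed in by the consumer instead of being stored on the
        # field, so they are never serialized along with every instance.
        self.indexed = {name: indexer(self.tokens) for name, indexer in token_indexers.items()}


def token_length_indexer(tokens: List[str]) -> List[int]:
    # Toy stand-in for a real indexer: uses each token's length as its "id".
    return [len(token) for token in tokens]


text_field = LightTextField(tokens=["A", "cat", "sat", "down"])
payload = pickle.dumps(text_field)  # what a worker would put on a queue
restored = pickle.loads(payload)  # what the consuming process would receive
restored.index({"lengths": token_length_indexer})
print(restored.indexed)  # {'lengths': [1, 3, 3, 4]}
```

Keeping the indexers out of the field means each pickled instance carries only its own data, and whichever process does the indexing supplies the indexers once.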
The flow of data would look like this (boxes represent separate Python processes):
```
+-----------------------------------------------+
|  read_raw_data(path) -> T1, T2, T3, T4, ...   |   This runs in a single worker process.
+-----------------------+-----------------------+
                        |
                        v
+-----------------------------------------------+
|  deserialize_raw_data(Ti) -> {...}            |   This step runs in an arbitrary number
|           |                                   |   of worker processes. The raw data
|           v                                   |   instances T1, T2, T3, ... are sent
|  make_instance(**{...}) -> Instance           |   through a queue to this worker pool.
+-----------------------+-----------------------+
                        |
                        |   Where the instances go after this depends on
                        |   whether we have already built the vocab or not.
                        |
                +-------+--------------------------+
                |                                  |
                | Training                         | Building vocab
                | (vocab already built)            |
                v                                  v
+-------------------------------+   +-----------------------------+
|  Instance.index_fields(vocab) |   |  Vocab.from_instances(...)  |   Runs in the main process.
|               |               |   +-----------------------------+
|               v               |
|  Instance.as_tensor_dict()    |
|               |               |
|               v               |
|  ... collects N tensor dicts, |
|  then does sorting, sampling, |
|  and batching. Finally, sends |
|  the batches on to the        |
|  trainer running in the main  |
|  process.                     |
+-------------------------------+
```
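The flow in the diagram can be approximated in a single script. This is a sketch under several assumptions: `ToyReader`, `run_pipeline`, `build_vocab` and `make_batches` are hypothetical names, instances are plain dicts rather than `Instance`s, and threads stand in for the worker processes so the example runs anywhere. With real processes (for example via `multiprocessing`), everything placed on the queues would have to be picklable, which is the serializability requirement stated at the top. Batching here also runs in the main thread rather than in its own process.

```python
import queue
import threading
from collections import Counter
from typing import Any, Dict, Iterable, Iterator, List

_DONE = object()  # sentinel marking the end of a stream
Instance = Dict[str, Any]  # hypothetical stand-in for allennlp's Instance


class ToyReader:
    """Hypothetical reader for "text<TAB>label" lines, using the proposed API."""

    def read_raw_data(self, lines: Iterable[str]) -> Iterator[str]:
        yield from lines  # as little work as possible: no parsing here

    def deserialize_raw_data(self, raw: str) -> Dict[str, str]:
        text, label = raw.split("\t")
        return {"text": text, "label": label}

    def make_instance(self, text: str, label: str) -> Instance:
        return {"tokens": text.split(), "label": label}


def run_pipeline(reader: ToyReader, lines: Iterable[str], num_workers: int = 2) -> List[Instance]:
    raw_queue: "queue.Queue[object]" = queue.Queue(maxsize=16)
    out_queue: "queue.Queue[object]" = queue.Queue()

    def produce() -> None:
        # The single reader: pushes raw objects T1, T2, ... onto a queue.
        for raw in reader.read_raw_data(lines):
            raw_queue.put(raw)
        for _ in range(num_workers):
            raw_queue.put(_DONE)  # one stop signal per worker

    def work() -> None:
        # One pool worker: deserialize_raw_data, then make_instance(**{...}).
        while (raw := raw_queue.get()) is not _DONE:
            out_queue.put(reader.make_instance(**reader.deserialize_raw_data(raw)))
        out_queue.put(_DONE)

    threads = [threading.Thread(target=produce)]
    threads += [threading.Thread(target=work) for _ in range(num_workers)]
    for thread in threads:
        thread.start()
    instances: List[Instance] = []
    finished = 0
    while finished < num_workers:  # drain until every worker has signalled
        item = out_queue.get()
        if item is _DONE:
            finished += 1
        else:
            instances.append(item)
    for thread in threads:
        thread.join()
    return instances


def build_vocab(instances: List[Instance]) -> Dict[str, int]:
    # "Building vocab" branch: token -> id, with 0 reserved for unknown tokens.
    counts = Counter(tok for inst in instances for tok in inst["tokens"])
    return {tok: i + 1 for i, tok in enumerate(sorted(counts))}


def make_batches(instances: List[Instance], vocab: Dict[str, int], batch_size: int) -> List[List[List[int]]]:
    # "Training" branch: index, sort by length (ties broken by ids), then batch.
    indexed = [[vocab.get(tok, 0) for tok in inst["tokens"]] for inst in instances]
    indexed.sort(key=lambda ids: (len(ids), ids))
    return [indexed[i : i + batch_size] for i in range(0, len(indexed), batch_size)]


lines = ["a cat\tpos", "the dog ran\tneg", "hi\tpos"]
instances = run_pipeline(ToyReader(), lines)
vocab = build_vocab(instances)
print(make_batches(instances, vocab, batch_size=2))  # [[[4], [1, 2]], [[6, 3, 5]]]
```

Because the workers finish in no particular order, the sketch sorts by length with a deterministic tie-breaker before batching; a real implementation would likely bucket or sample instead.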
```python
from typing import Iterable, TypeVar

from allennlp.common import Registrable
from allennlp.common.util import JsonDict
from allennlp.data import Instance

# Type of the raw data objects yielded by `read_raw_data`.
T = TypeVar("T")


class DatasetReader(Registrable):
    def read_raw_data(self, file_path: str) -> Iterable[T]:
        """
        Must be implemented by all subclasses.
        This is essentially what was called `_read` before, but instead of generating
        `Instance`s, it generates the rawest form of the data.
        Unlike `_read()`, it should do as little work as possible.
        All preprocessing or expensive deserialization should be done in either
        `deserialize_raw_data` or `make_instance`.
        The raw objects generated by this function will be passed on to
        `deserialize_raw_data()` in a separate worker process, and then through
        `make_instance()`.
        """
        raise NotImplementedError

    def deserialize_raw_data(self, raw: T) -> JsonDict:
        """
        Must be implemented by all subclasses.
        Should take a raw data object generated by `read_raw_data` and return a
        `JsonDict` that will then be passed to `make_instance`.
        """
        raise NotImplementedError

    def make_instance(self, *args, **kwargs) -> Instance:
        """
        Must be implemented by all subclasses.
        This is really the same as `text_to_instance`, or could possibly
        still be called that. I just don't love the naming of `text_to_instance`
        since we are talking about expanding our domain to images.
        """
        raise NotImplementedError
```
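As an illustration of how a subclass might split work across the three methods, here is a sketch for a hypothetical JSON-lines file of `{"sentence": ..., "label": ...}` records. To keep it runnable without AllenNLP, `DatasetReaderStub` is a hypothetical stand-in for the proposed base class and `make_instance` returns a plain dict in place of an `Instance`. `read_sequentially` is likewise a hypothetical helper that chains the three methods in one process instead of across the worker pool.

```python
import json
import os
import tempfile
from typing import Any, Dict, Iterator, List


class DatasetReaderStub:
    """Hypothetical stand-in for the proposed base class (no AllenNLP needed)."""

    def read_raw_data(self, file_path: str) -> Iterator[Any]:
        raise NotImplementedError

    def deserialize_raw_data(self, raw: Any) -> Dict[str, Any]:
        raise NotImplementedError

    def make_instance(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError


class JsonLinesReader(DatasetReaderStub):
    def read_raw_data(self, file_path: str) -> Iterator[str]:
        # Cheap work only: yield each non-empty line as an unparsed string.
        with open(file_path, encoding="utf-8") as handle:
            for line in handle:
                if line.strip():
                    yield line

    def deserialize_raw_data(self, raw: str) -> Dict[str, Any]:
        # The more expensive step (JSON parsing) would run in the worker pool.
        return json.loads(raw)

    def make_instance(self, sentence: str, label: str) -> Dict[str, Any]:
        # Stand-in for building an Instance; keeps only plain, picklable data.
        return {"tokens": sentence.split(), "label": label}


def read_sequentially(reader: DatasetReaderStub, file_path: str) -> List[Any]:
    # Single-process equivalent of the pipeline, handy for tests and debugging.
    return [
        reader.make_instance(**reader.deserialize_raw_data(raw))
        for raw in reader.read_raw_data(file_path)
    ]


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "data.jsonl")
    with open(path, "w", encoding="utf-8") as handle:
        handle.write('{"sentence": "a good movie", "label": "pos"}\n')
        handle.write('{"sentence": "not great", "label": "neg"}\n')
    instances = read_sequentially(JsonLinesReader(), path)

print(instances)
```

Because the keys of each deserialized record match the parameter names of `make_instance`, the `make_instance(**{...})` call from the diagram works without any glue code.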