Dataset Reader API
"""
Proposal for a new DatasetReader API.

For this to work, all `Instance`s would have to be efficiently serializable.
So `TextField`s, for example, shouldn't contain `TokenIndexer`s.

The flow of data would look like this (boxes represent separate Python processes;
a rough sketch of a driver loop that implements this flow is at the bottom of this file):
```
+------------------------------------------------+
|                                                |
|  read_raw_data(path) -> T1, T2, T3, T4, ...    |   This runs in a single worker
|                  |                             |   process.
+------------------|-----------------------------+
                   |
                   v
+------------------------------------------------+
|                                                |
|  deserialize_raw_data(Ti) -> {...}             |   This step runs in an arbitrary
|                  |                             |   number of worker processes. The
|                  v                             |   raw data instances T1, T2, T3, ...
|  make_instance(**{...}) -> Instance            |   are sent through a queue to this
|                  |                             |   worker pool.
+------------------|-----------------------------+
                   |                                 Where the instances go after this
                   |                                 depends on whether we have already
                   |                                 built the vocab or not.
         +---------+-------------------+
         |                             |
Training |                             | Building vocab
 (vocab  |                             |
 already |                             |
  built) |                             |
         v                             v
+--------------------------------+   +-----------------------------+
|                                |   |                             |
|  Instance.index_fields(vocab)  |   |  Vocab.from_instances(...)  |   Runs in the
|                                |   |                             |   main process.
|  Instance.as_tensor_dict()     |   +-----------------------------+
|                  |             |
+------------------|-------------+
                   |
                   v
+--------------------------------+
|                                |
|  Collects N tensor dicts,      |
|  then does sorting, sampling,  |
|  and batching, finally         |
|  sending the batches on to     |
|  the trainer running in the    |
|  main process.                 |
|                                |
+--------------------------------+
```
"""
from typing import Generic, Iterable, TypeVar

from allennlp.common import Registrable
from allennlp.common.util import JsonDict
from allennlp.data import Instance

T = TypeVar("T")


class DatasetReader(Registrable, Generic[T]):
    def read_raw_data(self, file_path: str) -> Iterable[T]:
        """
        Must be implemented by all subclasses.

        This is essentially what was called `_read` before, but instead of generating
        `Instance`s, it generates the rawest form of the data.
        Unlike `_read()`, it should do as little work as possible.
        All preprocessing or expensive deserialization should be done in either
        `deserialize_raw_data` or `make_instance`.

        The raw objects generated by this function will be passed on to
        `deserialize_raw_data()` in a separate worker process, and then through
        `make_instance()`.
        """
        raise NotImplementedError

    def deserialize_raw_data(self, raw: T) -> JsonDict:
        """
        Must be implemented by all subclasses.

        Should take a raw data object generated by `read_raw_data` and return a
        `JsonDict` that will then be passed to `make_instance`.
        """
        raise NotImplementedError

    def make_instance(self, *args, **kwargs) -> Instance:
        """
        Must be implemented by all subclasses.

        This is really the same as `text_to_instance`, and could possibly still be
        called that. I just don't love the name `text_to_instance`, since we are
        talking about expanding our domain to images.
        """
        raise NotImplementedError
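
# A minimal sketch of a concrete reader under this API, assuming a text
# classification dataset stored as JSON lines ({"text": ..., "label": ...}).
# The registered name, the file format, and the use of `WhitespaceTokenizer`
# are hypothetical, chosen only to show how the work splits across the three
# methods.
import json

from allennlp.data.fields import LabelField, TextField
from allennlp.data.tokenizers import WhitespaceTokenizer


@DatasetReader.register("jsonl-classification")
class JsonlClassificationReader(DatasetReader[str]):
    def __init__(self) -> None:
        self._tokenizer = WhitespaceTokenizer()

    def read_raw_data(self, file_path: str) -> Iterable[str]:
        # As cheap as possible: just stream raw lines, deferring all parsing
        # to the worker pool.
        with open(file_path) as data_file:
            for line in data_file:
                yield line

    def deserialize_raw_data(self, raw: str) -> JsonDict:
        # Runs in a worker process, so JSON parsing stays off the main process.
        return json.loads(raw)

    def make_instance(self, text: str, label: str) -> Instance:
        # Tokenization (the expensive part) also runs in a worker process.
        # Assumes `TextField` no longer holds `TokenIndexer`s, per the note
        # at the top of this file.
        tokens = self._tokenizer.tokenize(text)
        return Instance({"text": TextField(tokens), "label": LabelField(label)})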
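
# A rough sketch (not part of the proposal) of how a data loader might drive
# the flow in the diagram above with a `multiprocessing` worker pool. Every
# name below other than the three `DatasetReader` methods is hypothetical and
# only illustrates the intended division of labor.
import multiprocessing as mp


def _instance_worker(reader: DatasetReader, in_queue, out_queue) -> None:
    # Runs in each pool worker: all of the expensive work (deserialization
    # and `Instance` creation) happens here, off the main process.
    while True:
        raw = in_queue.get()
        if raw is None:  # sentinel: no more raw data coming
            break
        inputs = reader.deserialize_raw_data(raw)
        out_queue.put(reader.make_instance(**inputs))


def read_instances(reader: DatasetReader, file_path: str, num_workers: int = 4):
    # Hypothetical driver: streams raw data to a worker pool and yields the
    # resulting `Instance`s back in the main process. This relies on both the
    # raw objects and the `Instance`s being efficiently serializable.
    in_queue, out_queue = mp.Queue(), mp.Queue()
    workers = [
        mp.Process(target=_instance_worker, args=(reader, in_queue, out_queue))
        for _ in range(num_workers)
    ]
    for worker in workers:
        worker.start()
    # `read_raw_data` stays cheap: it just enumerates raw objects.
    num_raw = 0
    for raw in reader.read_raw_data(file_path):
        in_queue.put(raw)
        num_raw += 1
    for _ in workers:
        in_queue.put(None)  # one sentinel per worker
    for _ in range(num_raw):
        yield out_queue.get()
    for worker in workers:
        worker.join()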