An Elasticsearch reader for Dask
from dask import delayed
from elasticsearch import Elasticsearch
from elasticsearch.helpers import scan


def read_elasticsearch(query=None, npartitions=8, client_cls=None,
                       client_kwargs=None, **kwargs):
    """Reads documents from Elasticsearch.

    By default, documents are sorted by ``_doc``. For more information, see
    the scrolling section in the Elasticsearch documentation.

    Parameters
    ----------
    query : dict, optional
        Search query.
    npartitions : int, optional
        Number of partitions, default is 8.
    client_cls : elasticsearch.Elasticsearch, optional
        Elasticsearch client class.
    client_kwargs : dict, optional
        Elasticsearch client parameters.
    **kwargs
        Additional keyword arguments are passed to the
        ``elasticsearch.helpers.scan`` function.

    Returns
    -------
    out : List[Delayed]
        A list of ``dask.delayed.Delayed`` objects.

    Examples
    --------
    Get all documents in Elasticsearch.

    >>> docs = dask.bag.from_delayed(read_elasticsearch())

    Get documents matching a given query.

    >>> query = {"query": {"match_all": {}}}
    >>> docs = dask.bag.from_delayed(
    ...     read_elasticsearch(query, index="myindex", doc_type="stuff"))

    """
    query = query or {}
    # Sorting by _doc is the most efficient order for scrolling.
    query.setdefault('sort', ['_doc'])
    if client_cls is None:
        client_cls = Elasticsearch
    values = []
    for idx in range(npartitions):
        # Sliced scrolling: each partition scans a disjoint slice of the index.
        slice_ = {'id': idx, 'max': npartitions}
        scan_kwargs = dict(kwargs, query=dict(query, slice=slice_))
        values.append(
            delayed(_elasticsearch_scan)(client_cls, client_kwargs, **scan_kwargs)
        )
    return values


def _elasticsearch_scan(client_cls, client_kwargs, **params):
    # This function runs in the worker process; the ES client is instantiated
    # here because it cannot be serialized.
    client = client_cls(**(client_kwargs or {}))
    return list(scan(client, **params))
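
As a quick sketch of how the pieces fit together, the reader can be pointed at a remote cluster through ``client_kwargs`` and the returned fields trimmed with a ``_source`` filter in the query. The host address, index name, and field names below are placeholders, not part of the gist:

import dask.bag

from dask_elasticsearch import read_elasticsearch

# Placeholder host, index, and fields -- adjust for your cluster.
query = {
    "_source": ["timestamp", "message"],
    "query": {"term": {"level": "error"}},
}
parts = read_elasticsearch(
    query,
    npartitions=4,
    client_kwargs={"hosts": ["es.example.com:9200"]},
    index="my-logs",
)
docs = dask.bag.from_delayed(parts)

# Each hit carries metadata (_index, _id, ...); pluck the documents themselves.
print(docs.pluck("_source").take(2))

An interactive session exercising the reader against a local index:
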
In [2]: import dask.bag
In [3]: from dask_elasticsearch import read_elasticsearch
In [4]: docs = dask.bag.from_delayed(read_elasticsearch(index="myindex"))
In [5]: from dask.diagnostics import progress
In [6]: progress.ProgressBar().register()
In [7]: docs
dask.bag<bag-fro..., npartitions=8>
In [8]: docs.count().compute()
[########################################] | 100% Completed | 0.4s
346
In [9]: docs.map_partitions(len).compute()
[########################################] | 100% Completed | 0.3s
(46, 30, 33, 71, 66, 30, 33, 37)
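
Building on the session above, the bag of hits can also be turned into a dask DataFrame once the ``_source`` payloads are tabular. This is a sketch that assumes every document exposes the same fields:

# Assumes every hit's _source has the same set of fields.
df = docs.pluck("_source").to_dataframe()
print(df.head())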