An Avro reader for Dask (with fastavro)
"""A fastavro-based avro reader for Dask.
Disclaimer: This code was recovered from dask's distributed project.
"""
import io
import fastavro
import json
from dask import delayed
from dask.bytes import read_bytes
from dask.bytes.core import OpenFileCreator
def read_avro(urlpath, blocksize=2**27, **kwargs):
"""Reads avro files.
Parameters
----------
urlpath : string
Absolute or relative filepath, URL or globstring pointing to avro files.
blocksize : int, optional
Size of blocks. Default is 128MB.
**kwargs
Additional arguments passed to ``dask.bytes.read_bytes`` function.
Returns
-------
out : list[Delayed]
A list of delayed objects, one for each block.
"""
myopen = OpenFileCreator(urlpath)
values = []
for fn in myopen.fs.glob(urlpath):
with myopen(fn) as fp:
av = fastavro.reader(fp)
header = av._header
_, blockss = read_bytes(fn, delimiter=header['sync'], not_zero=True,
sample=False, blocksize=blocksize, **kwargs)
values.extend(
delayed(_avro_body)(block, header) for blocks in blockss for block in blocks
)
if not values:
raise ValueError("urlpath is empty: %s" % urlpath)
return values
def _avro_body(data, header):
"""Returns records for given avro data fragment."""
stream = io.BytesIO(data)
schema = json.loads(header['meta']['avro.schema'].decode())
codec = header['meta']['avro.codec'].decode()
return iter(fastavro._reader._iter_avro(stream, header, codec, schema, schema))
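The splitting strategy above hinges on the 16-byte sync marker that an Avro container writes between blocks. The sketch below (a minimal example assuming a local copy of the test file, and relying on the same private ``_header`` attribute that ``read_avro`` uses) shows what that header provides:

import fastavro

# Minimal sketch, assuming "test-snappy.avro" exists locally. ``_header``
# is a private fastavro attribute, matching its usage in read_avro above.
with open("test-snappy.avro", "rb") as fp:
    av = fastavro.reader(fp)
    header = av._header

print(len(header["sync"]))                    # sync markers are always 16 bytes
print(header["meta"]["avro.codec"].decode())  # e.g. "snappy"
print(header["meta"]["avro.schema"].decode()) # writer schema as JSON

The usage script below then ties everything together end to end: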
import dask.bag

from dask_avro import read_avro

# Download test file from https://github.com/tebeka/fastavro/blob/master/tests/avro-files/test-snappy.avro
data = dask.bag.from_delayed(read_avro('test-snappy.avro', blocksize=1024))

print("Partitions: %s" % data.npartitions)
print("Records: %s" % data.count().compute())

records = data.take(5)
print("Sample:\n %s" % "\n ".join(repr(obj) for obj in records))
Partitions: 14
Records: 400
Sample:
{'stringField': 'bcoopyccxlnvddvstmcjg', 'longField': -2244928484433289874}
{'stringField': 'xqrrmdsopdrliytknbalvjqrdgilrxhubly', 'longField': 3324743337320530674}
{'stringField': 'urndclybnuexuhwnp', 'longField': -4845990395235763935}
{'stringField': 'kxixmprjp', 'longField': 987436882195617773}
{'stringField': 'noywtcprnrpfvbbhvfmqtqijwigkunyyttwx', 'longField': 6800041140242693723}
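Since each record is a flat dict, the bag also converts naturally to a dask DataFrame for columnar work. A hedged follow-up sketch (``to_dataframe`` is a standard ``dask.bag`` method; the column name comes from the sample output above):

# Convert the bag of records into a dask DataFrame and aggregate a column.
df = data.to_dataframe()
print(df["longField"].mean().compute())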