from __future__ import annotations

import hashlib as hl
import inspect
import io
import zlib
from queue import Queue
from threading import Thread
from typing import TYPE_CHECKING, Any, AnyStr, Iterable, Iterator, TypeVar

if TYPE_CHECKING:
    from _typeshed import SupportsRead

AnyStr2 = TypeVar('AnyStr2', str, bytes)


def heredoc(s):
    """
    Useful for representing text file contents.

    Example:
        csv_str = heredoc('''
            abc,def
            123,456
        ''')
        csv_str == 'abc,def\n123,456\n'  # https://stackoverflow.com/a/729795/234593
    """
    return inspect.cleandoc(s) + '\n'
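

# Minimal sketch (illustrative, not part of the original gist): heredoc strips
# the common leading indentation, so multi-line text can sit inside nested code.
def _heredoc_demo():
    sql = heredoc('''
        SELECT id
        FROM users
    ''')
    assert sql == 'SELECT id\nFROM users\n'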


def as_bytes(s: str | bytes) -> bytes:
    if type(s) not in (str, bytes):
        raise TypeError(f'expected str or bytes, got {type(s)}')
    return s.encode() if isinstance(s, str) else s


def sha2(s: str | bytes):
    return hl.sha256(as_bytes(s)).hexdigest()
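

# Quick sketch (illustrative): sha2 accepts str or bytes and yields the same
# hex digest either way, since str input is utf-8 encoded first.
def _sha2_demo():
    assert sha2('abc') == sha2(b'abc') == hl.sha256(b'abc').hexdigest()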


def dig(collection, *keys):
    """
    "Dig" into nested dictionaries and lists.

    None-aware, useful for JSON.

    Examples:
        nested_dict = {'a': {'b': 1}, 'c': [2, 3]}
        assert dig(nested_dict, 'a', 'b') == 1
        assert dig(nested_dict, 'a', 'x', 'y') is None

        nested_list = [{'a': 3}, {'a': 4, 'b': 5}]
        assert dig(nested_list, 0, 'a') == 3
        assert dig(nested_list, 0, 'b') is None

    See tests for more examples.
    """
    curr = collection
    for k in keys:
        if curr is None:
            break
        if not hasattr(curr, '__getitem__') or isinstance(curr, str):
            raise TypeError(f'cannot dig into {type(curr)}')
        try:
            curr = curr[k]
        except (KeyError, IndexError):
            curr = None
    return curr
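

# Hedged sketch (the payload here is illustrative): dig pairs naturally with
# parsed JSON, where optional keys are common.
def _dig_json_demo():
    import json
    payload = json.loads('{"user": {"name": "ada", "tags": ["x"]}}')
    assert dig(payload, 'user', 'name') == 'ada'
    assert dig(payload, 'user', 'tags', 0) == 'x'
    assert dig(payload, 'user', 'missing', 0) is None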


def iter_io(iterable: Iterable[AnyStr], buffer_size: int = io.DEFAULT_BUFFER_SIZE):
    """
    Returns a buffered file obj that reads bytes from an iterable of str/bytes.

    Example:
        iter_io(['abc', 'def', 'g']).read() == b'abcdefg'
        iter_io([b'abcd', b'efg']).read(5) == b'abcde'
    """
    class IterIO(io.RawIOBase):
        def __init__(self, iterable: Iterable[AnyStr]):
            self._leftover = b''
            self._iterable = (as_bytes(s) for s in iterable if s)

        def readable(self):
            return True

        def readinto(self, buf):
            try:
                chunk = self._leftover or next(self._iterable)
            except StopIteration:
                return 0  # indicate EOF
            output, self._leftover = chunk[:len(buf)], chunk[len(buf):]
            buf[:len(output)] = output
            return len(output)

    return io.BufferedReader(IterIO(iterable), buffer_size=buffer_size)


def io_iter(fo: SupportsRead[AnyStr], size: int = io.DEFAULT_BUFFER_SIZE):
    """
    Returns an iterator that reads from a file obj in sized chunks.

    Example:
        list(io_iter(io.StringIO('abcdefg'), 3)) == ['abc', 'def', 'g']
        list(io_iter(io.BytesIO(b'abcdefg'), 4)) == [b'abcd', b'efg']

    Usage notes/TODO:
        * file obj isn't closed; fix w/ keep_open=False and an internal contextmanager
    """
    # fo.read(0) produces the matching empty sentinel ('' or b'') for iter()
    return iter(lambda: fo.read(size), fo.read(0))
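

# Hedged sketch (names are illustrative): iter_io and io_iter are inverses of a
# sort -- an iterable of str chunks is exposed as a file object, then re-chunked
# and hashed incrementally without materializing the whole stream.
def _iter_io_demo():
    fo = iter_io(f'line {i}\n' for i in range(3))
    h = hl.sha256()
    for block in io_iter(fo):
        h.update(block)
    assert h.hexdigest() == sha2('line 0\nline 1\nline 2\n')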


def igzip(chunks: Iterable[AnyStr], level=zlib.Z_DEFAULT_COMPRESSION):
    """
    Streaming gzip: lazily compresses an iterable of bytes or str (utf8).

    Example:
        gzipped_bytes_iter = igzip(['hello ', 'world!'])
        gzip.decompress(b''.join(gzipped_bytes_iter)) == b'hello world!'
    """
    def gen():
        gzip_format = 0b10000  # wbits flag: emit gzip header/trailer instead of raw zlib
        c = zlib.compressobj(level=level, wbits=zlib.MAX_WBITS + gzip_format)
        yield from (c.compress(as_bytes(chunk)) for chunk in chunks)
        yield c.flush()

    return filter(None, gen())  # drop empty chunks; compress() often buffers
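

# Round-trip sketch (assumes only the stdlib gzip module): the lazily
# compressed chunks should decompress back to the original bytes.
def _igzip_demo():
    import gzip
    gzipped = b''.join(igzip(['hello ', 'world!']))
    assert gzip.decompress(gzipped) == b'hello world!'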


def prefetch(iterable: Iterable[Any], n: int = 1) -> Iterator[Any]:
    """
    Prefetch an iterable via thread, yielding original contents as normal.

    Example:
        def slow_produce(*args):
            for x in args:
                time.sleep(1)
                yield x

        def slow_consume(iterable):
            for _ in iterable:
                time.sleep(1)

        slow_consume(prefetch(slow_produce('a', 'b')))  # takes 3 sec, not 4

        # Prefetch
        # produce: | 'a' | 'b' |
        # consume:       | 'a' | 'b' |
        # seconds: 0 --- 1 --- 2 --- 3

        # No prefetch
        # produce: | 'a' |       | 'b' |
        # consume:       | 'a' |       | 'b' |
        # seconds: 0 --- 1 --- 2 --- 3 --- 4

    Usage notes/TODO:
        * mem leak: Thread is GC'd only after iterable is fully consumed; fix w/ __del__
    """
    queue = Queue(n)
    finished = object()  # unique sentinel marking the end of the iterable

    def produce():
        for x in iterable:
            queue.put(x)
        queue.put(finished)

    t = Thread(target=produce, daemon=True)
    t.start()

    while True:
        item = queue.get()
        if item is finished:
            break
        else:
            yield item
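

# End-to-end sketch (illustrative, not part of the original gist): compress
# chunks on a background thread via prefetch and expose the gzip stream as a
# readable file object.
if __name__ == '__main__':
    import gzip

    chunks = (heredoc('''
        abc,def
        123,456
    ''') for _ in range(3))
    gz_fo = iter_io(prefetch(igzip(chunks)))
    assert gzip.decompress(gz_fo.read()) == b'abc,def\n123,456\n' * 3
    print('demo ok')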