Skip to content

Instantly share code, notes, and snippets.

Created May 18, 2018 06:47
Show Gist options
  • Save fcracker79/8fe28dfb71f616ea9ac26dec4ac1b95c to your computer and use it in GitHub Desktop.
Save fcracker79/8fe28dfb71f616ea9ac26dec4ac1b95c to your computer and use it in GitHub Desktop.
import itertools
import typing
from pprint import pprint
def full_join(d1: typing.Optional[dict], d2: typing.Optional[dict]) -> typing.Optional[dict]:
    """Merge two rows that share a join key; return None if either is missing.

    Because a missing side yields None, using this as the ``join_elements``
    callback of ``join`` keeps only matched pairs (inner-join behaviour).

    The result is a new dict: every key of *d1* in *d1*'s order, followed by
    the keys that only *d2* has. On a shared key the value from *d1* wins,
    even when that value is None.
    """
    if d1 is not None and d2 is not None:
        merged = dict(d1)
        for key, value in d2.items():
            # d1 has precedence: only fill in keys it does not already have.
            if key not in d1:
                merged[key] = value
        return merged
    return None
def left_join(d1: typing.Optional[dict], d2: typing.Optional[dict]) -> typing.Optional[dict]:
    """Combine a left row with an optional right row.

    - No left row: return None, so the pair is dropped.
    - Left row without a right row: return the left row object itself.
    - Both rows present: delegate to ``full_join``, where left values win.
    """
    if d1 is None:
        return None
    return d1 if d2 is None else full_join(d1, d2)
def aggregate(
        g: typing.Iterator[dict],
        extract_key: typing.Callable[[typing.Dict], typing.Any]) -> typing.Iterator[typing.List[dict]]:
    """Group consecutive rows of *g* that share the same key.

    Rows are streamed; each run of adjacent rows whose ``extract_key`` values
    compare equal is yielded as one list, in input order. Equal keys that are
    not adjacent form separate groups, so the input should already be ordered
    by the key (the same contract as ``itertools.groupby``).

    :param g: iterable of rows.
    :param extract_key: function mapping a row to its grouping key; keys only
        need to support ``==``.
    :return: iterator of non-empty lists of rows; yields nothing for empty input.
    """
    # groupby calls extract_key once per row. Unlike the previous hand-rolled
    # loop, it neither drops rows whose key changes nor splits runs of equal
    # keys into singletons.
    for _key, group in itertools.groupby(g, key=extract_key):
        yield list(group)
def join(
        g1: typing.Iterator[dict],
        g2: typing.Iterator[dict],
        join_elements: typing.Callable[[typing.Optional[dict], typing.Optional[dict]], typing.Optional[dict]],
        extract_key: typing.Callable[[typing.Dict], str]) -> typing.Iterator[dict]:
    """Sort-merge join two key-ordered streams of rows.

    Both *g1* and *g2* must already be sorted in ascending order of
    ``extract_key``. Each key should appear at most once per stream: on equal
    keys both streams advance by one row, so duplicate keys are not
    cross-matched.

    For each step of the merge, ``join_elements`` receives a pair of rows:

    - ``(row1, row2)`` when the key is present on both sides;
    - ``(row1, None)`` or ``(None, row2)`` when it is present on one side only.

    This includes rows left over after the other stream is exhausted. Whatever
    the callback returns is yielded, unless it is None. The callback therefore
    chooses the join flavour: a callback that returns None for a missing left
    row acts as a left join, one that merges any non-missing side acts as a
    full outer join, and so on.

    :param g1: left stream of rows, sorted by key.
    :param g2: right stream of rows, sorted by key.
    :param join_elements: combines a (possibly missing) left and right row;
        returning None drops the pair.
    :param extract_key: maps a row to its orderable join key.
    :return: iterator over the joined rows, in key order.
    """
    # Private end-of-stream marker. Rows themselves may be falsy (e.g. {}),
    # so truthiness cannot be used to detect "no current row".
    end = object()

    def yield_if_not_none(d1: typing.Optional[dict], d2: typing.Optional[dict]):
        result = join_elements(d1, d2)
        # Only None means "drop"; an empty dict is a legitimate result.
        if result is not None:
            yield result

    i1, i2 = iter(g1), iter(g2)
    o1, o2 = next(i1, end), next(i2, end)

    # Merge phase: always advance the side holding the smaller key.
    while o1 is not end and o2 is not end:
        k1, k2 = extract_key(o1), extract_key(o2)
        if k1 < k2:
            yield from yield_if_not_none(o1, None)
            o1 = next(i1, end)
        elif k1 > k2:
            yield from yield_if_not_none(None, o2)
            o2 = next(i2, end)
        else:
            yield from yield_if_not_none(o1, o2)
            o1, o2 = next(i1, end), next(i2, end)

    # Drain phase: whatever is left on either side has no partner.
    while o1 is not end:
        yield from yield_if_not_none(o1, None)
        o1 = next(i1, end)
    while o2 is not end:
        yield from yield_if_not_none(None, o2)
        o2 = next(i2, end)
if __name__ == '__main__':
    # Demo: left-join two user streams, then group a third stream.
    # u1 and u2 are generated in ascending user_id order (zero-padded ids);
    # in u3, 'fieldX' = n // 3 is non-decreasing.
    def _user_id(row):
        return row['user_id']

    u1 = ({'user_id': 'user_{:02}'.format(n), 'field1': n}
          for n in range(10))
    u2 = ({'user_id': 'user_{:02}'.format(n), 'field2': n}
          for n in range(1, 13, 2))
    u3 = ({'user_id': 'user_{:02}'.format(n), 'field1': n, 'fieldX': n // 3}
          for n in range(10))

    joined_rows = list(join(u1, u2, left_join, _user_id))
    pprint(joined_rows)
    grouped_rows = list(aggregate(u3, lambda row: row['fieldX']))
    pprint(grouped_rows)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment