@moradology
Created February 8, 2024 21:38

wtf.py

    #!/usr/bin/env python
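"""Compare a single-pass kerchunk reference merge against a distributed,
chunked merge of the same references, and report the size of each result."""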
import json
import multiprocessing
from typing import Optional, Sequence

    from kerchunk.combine import MultiZarrToZarr

# Concatenate along time; lat/lon are expected to be identical across inputs
CONCAT_DIMS = ['time']
IDENTICAL_DIMS = ['lat', 'lon']

def load_refs(ldjson_file: str) -> list[dict]:
    """Load kerchunk reference dicts from a line-delimited JSON file."""
    refs = []
    with open(ldjson_file, "r") as f:
        for line in f:
            # Each line is a JSON array; the reference dict is its first element
            refs.append(json.loads(line)[0])
    return refs
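
# Assumed input shape (an assumption, inferred from load_refs above): each
# LDJSON line is a one-element JSON array wrapping a kerchunk (version 1)
# reference dict, e.g.
#   [{"version": 1, "refs": {".zgroup": "{\"zarr_format\":2}", ...}}]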

def mzz(refs) -> MultiZarrToZarr:
    """Build a MultiZarrToZarr combiner over a list of reference dicts."""
    return MultiZarrToZarr(
        refs,
        concat_dims=CONCAT_DIMS,
        identical_dims=IDENTICAL_DIMS,
        # anon=True requests unauthenticated access to public object storage
        target_options={"anon": True},
        remote_options={"anon": True},
        remote_protocol=None,
    )

def merge_refs(refs: list[dict]) -> dict:
    """Merge all references in a single pass and translate to one dict."""
    return mzz(refs).translate()

    # Distributed workflow
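# The create_accumulator/add_input/merge_accumulators/extract_output functions
# below follow the shape of a combiner (e.g. an Apache Beam CombineFn), driven
# here by hand with a multiprocessing pool.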
def worker_func(refs: list[dict]) -> MultiZarrToZarr:
    """Fold a chunk of references into one accumulated MultiZarrToZarr."""
    def create_accumulator() -> Optional[MultiZarrToZarr]:
        return None

    def add_input(accumulator: Optional[MultiZarrToZarr], item: dict) -> MultiZarrToZarr:
        if accumulator is None:
            references = [item]
        else:
            # Realize the accumulated references, then combine with the new item
            references = [accumulator.translate(), item]
        return mzz(references)

    acc = create_accumulator()
    for ref in refs:
        acc = add_input(acc, ref)
    return acc

def distributed_merge(refs: list[list[dict]]) -> dict:
    """Merge reference chunks in parallel, then combine the partial results."""
    def merge_accumulators(accumulators: Sequence[MultiZarrToZarr]) -> MultiZarrToZarr:
        references = [a.translate() for a in accumulators]
        return mzz(references)

    def extract_output(accumulator: MultiZarrToZarr) -> dict:
        return accumulator.translate()

    with multiprocessing.Pool(4) as p:
        accumulators: list[MultiZarrToZarr] = p.map(worker_func, refs)
    merged = merge_accumulators(accumulators)
    return extract_output(merged)

def compare_merge_size(single_dict: dict, multi_dict: dict) -> None:
    """Report the JSON-serialized size of each merged reference dict."""
    single_bytes = len(json.dumps(single_dict).encode("utf-8"))
    multi_bytes = len(json.dumps(multi_dict).encode("utf-8"))
    print(f"The single-process dict is {single_bytes} bytes")
    print(f"The multi-process dict is {multi_bytes} bytes")

def main():
    refs = load_refs("single/inputs_raw_15286.json")

    # Expected result from a single-pass merge, compared against a
    # distributed merge over four chunks of the first five refs
    single_merge = merge_refs(refs)
    multi_refs = [[refs[0], refs[1]], [refs[2]], [refs[3]], [refs[4]]]
    multi_merge = distributed_merge(multi_refs)

    compare_merge_size(single_merge, multi_merge)

if __name__ == "__main__":
    main()
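
# Usage sketch (assumes the LDJSON input exists at the path used in main):
#   $ python wtf.py
#   The single-process dict is ... bytes
#   The multi-process dict is ... bytes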