Created
June 24, 2024 00:59
-
-
Save cloneofsimo/d31eee8a5352655cb45869694adf0880 to your computer and use it in GitHub Desktop.
MDS multiprocessed data merging to NFS; because writes are asynchronous, this approach is faster.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import json
from glob import glob
from tqdm import tqdm
from multiprocessing import Pool, Manager, cpu_count
def with_id(basename: str, shard_id: int) -> str:
    """Return *basename* with its second dotted component replaced by *shard_id*.

    Expects a dotted shard filename (e.g. ``shard.0000012.mds``) and rewrites
    the index component to a zero-padded 7-digit shard id.
    """
    head, _old_id, *tail = basename.split(".")
    return ".".join([head, f"{shard_id:07}", *tail])
def count_shards(subdir):
    """Return the number of shards listed in ``<subdir>/data/index.json``.

    Returns 0 when the index file does not exist.
    """
    index_filename = os.path.join(subdir, "data/index.json")
    if not os.path.isfile(index_filename):
        return 0
    # Use a context manager so the handle is closed deterministically
    # (the original `json.load(open(...))` leaked it until GC).
    with open(index_filename) as f:
        obj = json.load(f)
    return len(obj["shards"])
def process_subdir(args):
    """Renumber one subdirectory's shards and symlink them into the link dir.

    *args* is ``(subdir, start_shard_id, link_dir, shared_infos)``. Each shard
    entry from ``<subdir>/data/index.json`` gets its basenames rewritten to a
    global shard id (starting at ``start_shard_id``), a symlink is created in
    ``link_dir`` pointing at the original file, and the updated metadata dict
    is appended to ``shared_infos``, which is returned.

    Bug fix: the original returned the int ``shard_id`` when the index file
    was missing but a list otherwise; the caller flattens all results with
    ``itertools.chain(*...)``, so the int raised TypeError. Now the function
    always returns ``shared_infos``.
    """
    subdir, start_shard_id, link_dir, shared_infos = args
    index_filename = os.path.join(subdir, "data/index.json")
    if not os.path.isfile(index_filename):
        return shared_infos
    # Context manager instead of bare open() to avoid leaking the handle.
    with open(index_filename) as f:
        obj = json.load(f)
    shard_id = start_shard_id
    for info in obj["shards"]:
        old_basename = info["raw_data"]["basename"]
        new_basename = with_id(old_basename, shard_id)
        info["raw_data"]["basename"] = new_basename
        if info["zip_data"] is not None:
            # When zipped data exists, the zip file (not the raw file) is the
            # one that gets symlinked below.
            old_basename = info["zip_data"]["basename"]
            new_basename = with_id(old_basename, shard_id)
            info["zip_data"]["basename"] = new_basename
        old_filename = os.path.join(subdir, "data", old_basename)
        new_filename = os.path.join(link_dir, new_basename)
        if not os.path.exists(new_filename):
            try:
                os.symlink(old_filename, new_filename)
            except OSError as e:
                # Best-effort: report and continue so one bad link does not
                # abort the whole subdirectory.
                print(f"Failed to create symlink {new_filename} -> {old_filename}: {e}")
        shard_id += 1
        shared_infos.append(info)
    return shared_infos
# NOTE(review): mid-file import — conventionally this belongs at the top of
# the file with the other imports.
import itertools
def merge_shard_groups(root: str = "./vae_mds", link_dir: str = "./linked_mds") -> None:
    """Merge per-subdirectory MDS shard groups under *root* into one dataset.

    Counts shards in every subdirectory (in parallel), assigns each subdir a
    contiguous range of global shard ids, symlinks the shard files into
    *link_dir*, and writes a merged ``index.json`` there.
    """
    os.makedirs(link_dir, exist_ok=True)
    root = os.path.abspath(root)
    link_dir = os.path.abspath(link_dir)
    subdirs = sorted(glob(os.path.join(root, "*")))

    # Count shards per subdirectory in parallel (reads are I/O bound).
    with Pool(cpu_count()) as pool:
        shard_counts = list(
            tqdm(pool.imap(count_shards, subdirs), total=len(subdirs), desc="Counting shards")
        )

    # Exclusive prefix sum: each subdir's first global shard id.
    shard_offsets = []
    total_shards = 0
    for count in shard_counts:
        shard_offsets.append(total_shards)
        total_shards += count

    args = [
        (subdir, offset, link_dir, [])
        for subdir, offset in zip(subdirs, shard_offsets)
    ]
    # NOTE(review): pool size 16 is hard-coded here (unlike the counting pool
    # above, which uses cpu_count()); preserved as-is.
    with Pool(16) as pool:
        shared_infos = list(
            tqdm(pool.imap(process_subdir, args), total=len(args), desc="Merging shards")
        )

    obj = {
        "version": 2,
        # Flatten the per-subdir shard lists into a single ordered list.
        "shards": list(itertools.chain.from_iterable(shared_infos)),
    }
    index_filename = os.path.join(link_dir, "index.json")
    with open(index_filename, "w") as out:
        out.write(json.dumps(obj, sort_keys=True, indent=4))
if __name__ == "__main__":
    # Guard the entry point: this module uses multiprocessing, and without the
    # guard, worker processes that re-import the module (spawn start method)
    # would re-run the merge.
    merge_shard_groups('/jfs/mds_original', '/jfs/mds_relinked_0')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment