@cloneofsimo
Created June 24, 2024 00:59
MDS multiprocessed data merging to NFS; because writing is async, this is faster than copying the shards.
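For orientation, each source subdirectory is expected to hold a data/index.json in the MosaicML Streaming (MDS) format, whose "shards" list the script below rewrites. A minimal sketch of one shard entry, with illustrative values (only the "raw_data"/"zip_data" "basename" fields are actually touched; "samples" and the exact byte counts here are assumptions, not taken from the gist):

example_shard_info = {
    "raw_data": {"basename": "shard.00000.mds", "bytes": 67108864, "hashes": {}},
    "zip_data": None,  # a dict shaped like raw_data when the shard is compressed
    "samples": 4096,   # illustrative; all other MDS fields are carried through untouched
}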
import itertools
import json
import os
from glob import glob
from multiprocessing import Pool, cpu_count

from tqdm import tqdm
def with_id(basename: str, shard_id: int) -> str:
    """Rewrite the numeric component of an MDS shard basename to `shard_id`."""
    parts = basename.split(".")
    parts[1] = f"{shard_id:07}"
    return ".".join(parts)
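# e.g. with_id("shard.00005.mds", 123) -> "shard.0000123.mds"
# (assumes the usual MDS basename layout "<prefix>.<number>.<ext>[.<compression>]")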
def count_shards(subdir):
    index_filename = os.path.join(subdir, "data/index.json")
    if not os.path.isfile(index_filename):
        return 0
    with open(index_filename) as f:
        obj = json.load(f)
    return len(obj["shards"])
def process_subdir(args):
    subdir, start_shard_id, link_dir, shared_infos = args
    shard_id = start_shard_id
    index_filename = os.path.join(subdir, "data/index.json")
    if not os.path.isfile(index_filename):
        return shared_infos  # nothing to merge in this subdir
    with open(index_filename) as f:
        obj = json.load(f)
    for info in obj["shards"]:
        # Give the raw shard a globally unique basename.
        old_basename = info["raw_data"]["basename"]
        new_basename = with_id(old_basename, shard_id)
        info["raw_data"]["basename"] = new_basename
        # If the shard is compressed, rename the zip entry too; the symlink
        # below then targets the zip file, since that is what exists on disk.
        if info["zip_data"] is not None:
            old_basename = info["zip_data"]["basename"]
            new_basename = with_id(old_basename, shard_id)
            info["zip_data"]["basename"] = new_basename
        # Symlink instead of copy: the actual bytes stay where they are.
        old_filename = os.path.join(subdir, "data", old_basename)
        new_filename = os.path.join(link_dir, new_basename)
        if not os.path.exists(new_filename):
            try:
                os.symlink(old_filename, new_filename)
            except OSError as e:
                print(f"Failed to create symlink {new_filename} -> {old_filename}: {e}")
        shard_id += 1
        shared_infos.append(info)
    return shared_infos
def merge_shard_groups(root: str = "./vae_mds", link_dir: str = "./linked_mds") -> None:
    os.makedirs(link_dir, exist_ok=True)
    root = os.path.abspath(root)
    link_dir = os.path.abspath(link_dir)
    pattern = os.path.join(root, "*")
    subdirs = sorted(glob(pattern))

    # Count shards per subdir in parallel so each subdir can be assigned a
    # disjoint range of global shard ids.
    with Pool(cpu_count()) as pool:
        shard_counts = list(tqdm(pool.imap(count_shards, subdirs), total=len(subdirs), desc="Counting shards"))

    # Prefix-sum the counts into per-subdir starting offsets.
    shard_offsets = [0] * len(subdirs)
    total_shards = 0
    for i in range(len(subdirs)):
        shard_offsets[i] = total_shards
        total_shards += shard_counts[i]

    # Rewrite basenames and create symlinks in parallel; each worker returns
    # the updated shard infos for its subdir.
    args = [(subdirs[i], shard_offsets[i], link_dir, []) for i in range(len(subdirs))]
    with Pool(16) as pool:
        shared_infos = list(tqdm(pool.imap(process_subdir, args), total=len(args), desc="Merging shards"))

    # Flatten the per-subdir lists into a single merged index.
    index_filename = os.path.join(link_dir, "index.json")
    obj = {
        "version": 2,
        "shards": list(itertools.chain(*shared_infos)),
    }
    text = json.dumps(obj, sort_keys=True, indent=4)
    with open(index_filename, "w") as out:
        out.write(text)
if __name__ == "__main__":  # guard needed when multiprocessing uses the spawn start method
    merge_shard_groups("/jfs/mds_original", "/jfs/mds_relinked_0")
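After the merge, the linked directory should behave like a single MDS dataset. A minimal consumption sketch, assuming the mosaicml-streaming package is installed and that pointing StreamingDataset's local argument at the merged path is what you want (names per that library's public API; this is illustrative and not part of the original gist):

from streaming import StreamingDataset

# Point `local` at the merged directory; shards resolve through the symlinks.
ds = StreamingDataset(local="/jfs/mds_relinked_0", shuffle=False)
print(len(ds))  # total samples across all merged shards
sample = ds[0]  # a dict keyed by the dataset's column names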