Skip to content

Instantly share code, notes, and snippets.

@jxmorris12
Created January 19, 2024 23:24
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jxmorris12/eedb24a06530defb0584c819de628cc6 to your computer and use it in GitHub Desktop.
Save jxmorris12/eedb24a06530defb0584c819de628cc6 to your computer and use it in GitHub Desktop.
datasets_fast_load_from_disk.py
import concurrent.futures
import glob
import json
import multiprocessing
import os
from typing import Iterable

import datasets
import tqdm
def load_dataset_tables(
    files: Iterable[str], num_workers: int
) -> Iterable[datasets.table.MemoryMappedTable]:
    """Memory-map a collection of Arrow files in parallel.

    Args:
        files: Paths of ``.arrow`` files on disk.
        num_workers: Number of parallel workers used to open the files.

    Returns:
        List of memory-mapped tables, in the same order as ``files``.
    """
    # Materialize once: ``files`` is consumed both by the worker pool and by
    # two len() calls below, which would break for a one-shot generator.
    files = list(files)
    use_threads = False
    if use_threads:
        pool_cls = concurrent.futures.ThreadPoolExecutor
        pool_kwargs = {"max_workers": num_workers}
    else:
        pool_cls = multiprocessing.Pool
        pool_kwargs = {"processes": num_workers}
    with pool_cls(**pool_kwargs) as pool:
        # Bug fix: ThreadPoolExecutor has no ``imap`` method (that is a
        # multiprocessing.Pool API); pick the lazy map each pool provides.
        map_fn = pool.map if use_threads else pool.imap
        result = list(
            tqdm.tqdm(
                map_fn(datasets.table.MemoryMappedTable.from_file, files),
                desc=f"Loading {len(files)} files with {num_workers} workers",
                total=len(files),
            )
        )
    return result
def datasets_fast_load_from_disk(cache_path: str) -> datasets.Dataset:
    """Reconstruct a ``datasets.Dataset`` from a ``save_to_disk`` directory.

    Faster alternative to ``datasets.load_from_disk``: reads the saved
    ``dataset_info.json`` and ``state.json`` metadata, memory-maps every
    ``*.arrow`` shard in parallel, concatenates them, and rebuilds the
    Dataset with its original split and fingerprint.

    Args:
        cache_path: Directory previously written by ``Dataset.save_to_disk``.

    Returns:
        The reconstructed dataset.
    """
    print(f"fast_load_from_disk called with path: {cache_path}")

    dataset_info_path = os.path.join(cache_path, "dataset_info.json")
    with open(dataset_info_path, encoding="utf-8") as dataset_info_file:
        dataset_info = datasets.DatasetInfo.from_dict(json.load(dataset_info_file))

    dataset_state_path = os.path.join(cache_path, "state.json")
    with open(dataset_state_path, encoding="utf-8") as state_file:
        state = json.load(state_file)

    # Sort so shards are concatenated in deterministic filename order.
    files = sorted(glob.glob(os.path.join(cache_path, "*.arrow")))
    num_workers = 16  # NOTE(review): hard-coded; consider exposing to callers.
    ds_tables = load_dataset_tables(files=files, num_workers=num_workers)
    arrow_table = datasets.table.concat_tables(ds_tables)

    split = state["_split"]
    # Bug fix: original read ``dataset.splits.Split`` — an undefined name
    # (missing the trailing "s") that raised NameError for any non-None split.
    split = datasets.splits.Split(split) if split is not None else split
    return datasets.Dataset(
        arrow_table=arrow_table,
        info=dataset_info,
        split=split,
        fingerprint=state["_fingerprint"],
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment