Use Dask to write a dataset to Hugging Face in a distributed manner
import math
import tempfile
from functools import partial

import dask.dataframe as dd
import pandas as pd
from huggingface_hub import CommitOperationAdd, HfFileSystem
def _preupload(df: pd.DataFrame, path: str, filesystem: HfFileSystem, **kwargs) -> pd.DataFrame:
    # Write one partition to a temporary Parquet file and pre-upload it to the repo's LFS storage.
    # The actual commit happens later, in `_commit`, once every partition has been pre-uploaded.
    resolved_path = filesystem.resolve_path(path)
    with tempfile.NamedTemporaryFile(suffix=".parquet") as temp_file:
        df.to_parquet(temp_file.name, **kwargs)
        addition = CommitOperationAdd(path_in_repo=temp_file.name, path_or_fileobj=temp_file.name)
        filesystem._api.preupload_lfs_files(
            repo_id=resolved_path.repo_id, additions=[addition],
            repo_type=resolved_path.repo_type, revision=resolved_path.revision,
        )
    return pd.DataFrame({"addition": pd.Series([addition], dtype="object")})
def _commit(df: pd.DataFrame, path: str, filesystem: HfFileSystem, max_operations_per_commit=50) -> pd.DataFrame:
    # Rename the pre-uploaded shards and commit them in batches of `max_operations_per_commit` files.
    resolved_path = filesystem.resolve_path(path)
    additions: list[CommitOperationAdd] = list(df["addition"])
    num_commits = math.ceil(len(additions) / max_operations_per_commit)
    for shard_idx, addition in enumerate(additions):
        addition.path_in_repo = resolved_path.path_in_repo.replace("{shard_idx:05d}", f"{shard_idx:05d}")
    for i in range(num_commits):
        operations = additions[i * max_operations_per_commit : (i + 1) * max_operations_per_commit]
        commit_message = "Upload using Dask" + (f" (part {i:05d}-of-{num_commits:05d})" if num_commits > 1 else "")
        filesystem._api.create_commit(
            repo_id=resolved_path.repo_id, repo_type=resolved_path.repo_type,
            revision=resolved_path.revision, operations=operations, commit_message=commit_message,
        )
    return pd.DataFrame({"path": pd.Series([addition.path_in_repo for addition in additions], dtype="string")})
def to_parquet(df: dd.DataFrame, path: str, **kwargs) -> None:
    """
    Write Parquet files to Hugging Face in a distributed manner with Dask.

    The upload happens in two steps:

    1. Preupload the Parquet files in parallel in a distributed manner
    2. Commit the preuploaded files

    Authenticate using `huggingface-cli login` or by passing a token
    using the `storage_options` argument: `storage_options={"token": "hf_xxx"}`

    Parameters
    ----------
    df : dask.dataframe.DataFrame
        DataFrame to write to Hugging Face.
    path : str
        Path of the file or directory. Prefix with a protocol like `hf://` to write to Hugging Face.
        It writes Parquet files in the form "part-xxxxx.parquet", or to a single file if `path` ends with ".parquet".
        The dataset repository must exist on Hugging Face prior to uploading files.
    **kwargs
        Any additional kwargs are passed to `pandas.DataFrame.to_parquet`.

    Returns
    -------
    None

    Examples
    --------
    >>> df = dd.from_dict({"foo": range(5), "bar": range(5, 10)}, npartitions=1)
    >>> # Save to one file
    >>> to_parquet(df, "hf://datasets/username/dataset/data.parquet")
    >>> # OR save to a directory (possibly in many files)
    >>> to_parquet(df, "hf://datasets/username/dataset")
    """
    filesystem: HfFileSystem = kwargs.pop("filesystem", HfFileSystem(**kwargs.pop("storage_options", {})))
    if path.endswith(".parquet") or path.endswith(".pq"):
        # Single output file: collapse the DataFrame into one partition.
        df = df.repartition(npartitions=1)
    else:
        # Directory: write one shard per partition, named part-00000.parquet, part-00001.parquet, ...
        path += "/part-{shard_idx:05d}.parquet"
    df.map_partitions(
        partial(_preupload, path=path, filesystem=filesystem, **kwargs),
        meta={"addition": "object"},
    ).repartition(npartitions=1).map_partitions(
        partial(_commit, path=path, filesystem=filesystem),
        meta={"path": "string"},
    ).compute()
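
For reference, here is a minimal usage sketch. The repository id "username/my-dataset" and the token value are placeholders; the target dataset repo must exist before uploading (it is created below with `create_repo`), and the token can be omitted if you are already logged in via `huggingface-cli login`.

import dask.dataframe as dd
from huggingface_hub import create_repo

# Create the target dataset repository if it does not exist yet (placeholder repo id).
create_repo("username/my-dataset", repo_type="dataset", exist_ok=True)

# Build a small Dask DataFrame and upload it as sharded Parquet files, one shard per partition.
df = dd.from_dict({"foo": range(1000), "bar": range(1000, 2000)}, npartitions=4)
to_parquet(df, "hf://datasets/username/my-dataset", storage_options={"token": "hf_xxx"})

Once uploaded, the shards (part-00000.parquet, part-00001.parquet, ...) can be read back with pandas or Dask through the same `hf://` paths.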