
@Helw150
Created May 7, 2024 21:41
def _push_parquet_shards_to_hub(
    self,
    repo_id: str,
    data_dir: str = "data",
    split: Optional[str] = None,
    token: Optional[str] = None,
    revision: Optional[str] = None,
    create_pr: Optional[bool] = False,
    max_shard_size: Optional[Union[int, str]] = None,
    num_shards: Optional[int] = None,
    embed_external_files: bool = True,
) -> Tuple[List[CommitOperationAdd], int, int]:
    """Pushes the dataset shards as Parquet files to the hub.

    Returns:
        additions (`List[CommitOperation]`): list of the `CommitOperationAdd` of the uploaded shards
        uploaded_size (`int`): number of uploaded bytes to the repository
        dataset_nbytes (`int`): approximate size in bytes of the uploaded dataset after uncompression
    """
    # Find decodable columns, because if there are any, we need to:
    # embed the bytes from the files in the shards
    decodable_columns = (
        [k for k, v in self._info.features.items() if require_decoding(v, ignore_decode_attribute=True)]
        if embed_external_files
        else []
    )
    dataset_nbytes = self._estimate_nbytes()
    if num_shards is None:
        max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE)
        num_shards = int(dataset_nbytes / max_shard_size) + 1
        num_shards = max(num_shards, 1)
    shards = (self.shard(num_shards=num_shards, index=i, contiguous=True) for i in range(num_shards))
    if decodable_columns:

        def shards_with_embedded_external_files(shards):
            for i, shard in enumerate(shards):
                # Only embed external files for the shards that are actually
                # uploaded below (indices 2969-3000); all other shards are
                # yielded untouched and then skipped in the upload loop.
                if i <= 2968 or i >= 3001:
                    yield shard
                    continue
                format = shard.format
                shard = shard.with_format("arrow")
                shard = shard.map(
                    embed_table_storage,
                    batched=True,
                    batch_size=1000,
                    keep_in_memory=True,
                )
                shard = shard.with_format(**format)
                yield shard

        shards = shards_with_embedded_external_files(shards)
    api = HfApi(endpoint=config.HF_ENDPOINT, token=token)
    uploaded_size = 0
    finished_upload_size = 0
    additions = []
    part = 0
    for index, shard in hf_tqdm(
        enumerate(shards),
        desc="Uploading the dataset shards",
        total=num_shards,
    ):
        # Resume logic: skip the shards outside the 2969-3000 window, which
        # were already pushed (or will be pushed) in a separate run.
        if index <= 2968 or index >= 3001:
            continue
        shard_path_in_repo = f"{data_dir}/{split}-{index:05d}-of-{num_shards:05d}.parquet"
        buffer = BytesIO()
        shard.to_parquet(buffer)
        uploaded_size += buffer.tell()
        shard_addition = CommitOperationAdd(path_in_repo=shard_path_in_repo, path_or_fileobj=buffer)
        api.preupload_lfs_files(
            repo_id=repo_id,
            additions=[shard_addition],
            repo_type="dataset",
            revision=revision,
            create_pr=create_pr,
        )
        additions.append(shard_addition)
        # Commit in roughly 10 GB chunks so a failure late in the upload does
        # not lose every preuploaded shard.
        if uploaded_size - finished_upload_size > 10_000_000_000:
            _ = api.create_commit(
                repo_id,
                operations=additions,
                commit_message=f"Dataset Upload - Part {part}",
                token=token,
                repo_type="dataset",
                revision=revision,
            )
            part += 1
            additions = []
            finished_upload_size = uploaded_size
    return additions, uploaded_size, dataset_nbytes
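
For context, this appears to be a patched copy of `datasets.Dataset._push_parquet_shards_to_hub`: shards outside indices 2969-3000 are skipped so a partially failed upload can be resumed, and a commit is created roughly every 10 GB of preuploaded shards instead of a single commit at the end. Below is a minimal sketch of how the patched method could be wired in, assuming the function above is defined at module level; the import paths follow the `datasets` internals around v2.19 and may differ between releases.

# Imports the patched function relies on. These are internal datasets modules
# (assumed paths, valid around datasets ~2.19; check your installed version).
from io import BytesIO
from typing import List, Optional, Tuple, Union

from huggingface_hub import CommitOperationAdd, HfApi

from datasets import Dataset, config
from datasets.features.features import require_decoding
from datasets.table import embed_table_storage
from datasets.utils import tqdm as hf_tqdm
from datasets.utils.py_utils import convert_file_size_to_int

# Monkey-patch the method so Dataset.push_to_hub picks up the resumable,
# chunked-commit behaviour instead of the stock implementation.
Dataset._push_parquet_shards_to_hub = _push_parquet_shards_to_hub

# Then push as usual; only shards 2969-3000 of the split are re-uploaded.
# The repo name below is hypothetical, and num_shards must match the original
# run so the shard indices and file names line up.
# ds.push_to_hub("username/my-dataset", num_shards=...)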