Created
February 12, 2023 15:22
-
-
Save lewtun/53922bcb29b025cf0a060cb3aeb1e9cd to your computer and use it in GitHub Desktop.
Sharding dataset subsets
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 136, | |
"id": "e4f7bfb7-6e5a-41cf-a886-8b71235b3b91", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from pathlib import Path\n", | |
"\n", | |
"from datasets import Dataset\n", | |
"from datasets.utils.py_utils import convert_file_size_to_int" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 137, | |
"id": "1bb9e17d-e079-4aad-86e6-71c46f7ffef1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Synthetic QA records: every row shares the same question and answer text;\n", | |
"# only `qid` and the nested answer `id` (always qid + 1) vary per row.\n", | |
"data = [\n", | |
"    dict(\n", | |
"        qid=row_id,\n", | |
"        question=\"what is the meaning of life?\",\n", | |
"        answers=[dict(text=\"obviously 42\", id=row_id + 1)],\n", | |
"    )\n", | |
"    for row_id in range(100000)\n", | |
"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 138, | |
"id": "4032589e-0bd9-478e-87ed-f607361d82f4", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Dataset({\n", | |
" features: ['qid', 'question', 'answers'],\n", | |
" num_rows: 100000\n", | |
"})" | |
] | |
}, | |
"execution_count": 138, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ds = Dataset.from_list(data)\n", | |
"ds" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 121, | |
"id": "a17fa8c3-8cba-4b47-a9fe-00cc76fd97d9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'qid': [0, 1],\n", | |
" 'question': ['what is the meaning of life?', 'what is the meaning of life?'],\n", | |
" 'answers': [[{'id': 1, 'text': 'obviously 42'}],\n", | |
" [{'id': 2, 'text': 'obviously 42'}]]}" | |
] | |
}, | |
"execution_count": 121, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ds[:2]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 122, | |
"id": "0df81ac7-b9ee-442c-a223-adf63d7adabd", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"6800000" | |
] | |
}, | |
"execution_count": 122, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Size estimate used further down to derive the number of shards.\n", | |
"# NOTE: `_estimate_nbytes` is a private `datasets` method (leading underscore),\n", | |
"# so it may change between library releases.\n", | |
"dataset_nbytes = ds._estimate_nbytes()\n", | |
"dataset_nbytes" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 123, | |
"id": "a96a3efc-13e7-4bb6-a3a1-5e2c03f27a1e", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"6800000" | |
] | |
}, | |
"execution_count": 123, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ds._estimate_nbytes()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 145, | |
"id": "34902b53-b2c1-4b03-97e4-e8ec9b603daa", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def save_shards(dataset, path_to_repo, domain, shard_size=\"100MB\"):\n", | |
"    \"\"\"Write `dataset` as Parquet shards under `{path_to_repo}/data/{domain}`.\n", | |
"\n", | |
"    The shard count is derived from the dataset's estimated size so that each\n", | |
"    shard's estimated size stays under `shard_size` where possible. It is capped\n", | |
"    at the number of rows so that no zero-row shard file is written, and at\n", | |
"    least one shard is always written.\n", | |
"\n", | |
"    Args:\n", | |
"        dataset: The `datasets.Dataset` to split and save.\n", | |
"        path_to_repo: Root directory of the target repository.\n", | |
"        domain: Subset name, used as the sub-directory under `data/`.\n", | |
"        shard_size: Maximum shard size, in any form accepted by\n", | |
"            `convert_file_size_to_int` (e.g. \"100MB\").\n", | |
"\n", | |
"    Raises:\n", | |
"        ValueError: If `shard_size` does not resolve to a positive size.\n", | |
"    \"\"\"\n", | |
"    max_shard_size = convert_file_size_to_int(shard_size)\n", | |
"    if max_shard_size <= 0:\n", | |
"        # Fail early with a clear message instead of a ZeroDivisionError (0)\n", | |
"        # or a silently meaningless shard count (negative sizes).\n", | |
"        raise ValueError(f\"shard_size must be positive, got {shard_size!r}\")\n", | |
"\n", | |
"    path = Path(f\"{path_to_repo}/data/{domain}\")\n", | |
"    path.mkdir(parents=True, exist_ok=True)\n", | |
"\n", | |
"    # NOTE: `_estimate_nbytes` is a private `datasets` method and may change\n", | |
"    # between library releases.\n", | |
"    dataset_nbytes = dataset._estimate_nbytes()\n", | |
"    num_shards = int(dataset_nbytes / max_shard_size) + 1\n", | |
"    # Never request more shards than rows (the surplus shards would be empty\n", | |
"    # files), but always write at least one shard, even for an empty dataset.\n", | |
"    num_shards = max(min(num_shards, len(dataset)), 1)\n", | |
"    print(f\"Saving the dataset with {num_shards=}\")\n", | |
"\n", | |
"    for shard_idx in range(num_shards):\n", | |
"        sharded_ds = dataset.shard(\n", | |
"            num_shards=num_shards, index=shard_idx, contiguous=True\n", | |
"        )\n", | |
"        shard_file = path / f\"train-{shard_idx:05d}-of-{num_shards:05d}.parquet\"\n", | |
"        sharded_ds.to_parquet(str(shard_file))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 146, | |
"id": "1b8da35b-958a-4c28-a983-8d0376e0ca1d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Saving the dataset with num_shards=2\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "eef94ca3f4134818a7694601047835a2", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Creating parquet from Arrow format: 0%| | 0/50 [00:00<?, ?ba/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "f4a035be666e47d49e0ff3170aa089fc", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Creating parquet from Arrow format: 0%| | 0/50 [00:00<?, ?ba/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"save_shards(ds, \".\", \"ai\", shard_size=\"5MB\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 144, | |
"id": "c5f13bcc-3909-4500-bdd4-d846b3c3daa1", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train-00000-of-00002.parquet train-00001-of-00002.parquet\n" | |
] | |
} | |
], | |
"source": [ | |
"!ls train-*" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 125, | |
"id": "27093506-e562-44fd-80f6-fbf599d659f1", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"2" | |
] | |
}, | |
"execution_count": 125, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Scratch check of the shard-count arithmetic for a 5MB shard limit.\n", | |
"# `dataset_nbytes` comes from the earlier cell that calls `ds._estimate_nbytes()`.\n", | |
"max_shard_size = convert_file_size_to_int(\"5MB\")\n", | |
"num_shards = max(int(dataset_nbytes / max_shard_size) + 1, 1)\n", | |
"num_shards" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 89, | |
"id": "4a0b384a-df56-4009-b137-fe1a80502ff8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'qid': [50000, 50001, 50002],\n", | |
" 'question': ['what is the meaning of life?',\n", | |
" 'what is the meaning of life?',\n", | |
" 'what is the meaning of life?'],\n", | |
" 'answers': [[{'id': 50001, 'text': 'obviously 42'}],\n", | |
" [{'id': 50002, 'text': 'obviously 42'}],\n", | |
" [{'id': 50003, 'text': 'obviously 42'}]]}" | |
] | |
}, | |
"execution_count": 89, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ds.shard(num_shards=4, index=2, contiguous=True)[:3]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 126, | |
"id": "1e6dab15-4c11-4121-b3dd-5548a38c74ab", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "c8b340b54d2a4366ad31f8d8f06f91d9", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Creating parquet from Arrow format: 0%| | 0/50 [00:00<?, ?ba/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "d71fb52be1a143868b836914c3e7eb0e", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Creating parquet from Arrow format: 0%| | 0/50 [00:00<?, ?ba/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Write each contiguous slice of `ds` to its own Parquet file in the\n", | |
"# current working directory, named train-<index>-of-<total>.parquet.\n", | |
"for shard_idx in range(num_shards):\n", | |
"    sharded_ds = ds.shard(\n", | |
"        num_shards=num_shards,\n", | |
"        index=shard_idx,\n", | |
"        contiguous=True,\n", | |
"    )\n", | |
"    sharded_ds.to_parquet(\n", | |
"        f\"train-{shard_idx:05d}-of-{num_shards:05d}.parquet\"\n", | |
"    )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "e1d662a7-bf3f-4b89-bf23-3e03bedba1ab", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "hf", | |
"language": "python", | |
"name": "hf" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment