Skip to content

Instantly share code, notes, and snippets.

@lewtun
Created February 12, 2023 15:22
Show Gist options
  • Save lewtun/53922bcb29b025cf0a060cb3aeb1e9cd to your computer and use it in GitHub Desktop.
Save lewtun/53922bcb29b025cf0a060cb3aeb1e9cd to your computer and use it in GitHub Desktop.
Sharding dataset subsets
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 136,
"id": "e4f7bfb7-6e5a-41cf-a886-8b71235b3b91",
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"from datasets import Dataset\n",
"from datasets.utils.py_utils import convert_file_size_to_int"
]
},
{
"cell_type": "code",
"execution_count": 137,
"id": "1bb9e17d-e079-4aad-86e6-71c46f7ffef1",
"metadata": {},
"outputs": [],
"source": [
"data = [\n",
" {\n",
" \"qid\": idx,\n",
" \"question\": \"what is the meaning of life?\",\n",
" \"answers\": [{\"text\": \"obviously 42\", \"id\": idx + 1}],\n",
" }\n",
" for idx in range(100000)\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 138,
"id": "4032589e-0bd9-478e-87ed-f607361d82f4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['qid', 'question', 'answers'],\n",
" num_rows: 100000\n",
"})"
]
},
"execution_count": 138,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds = Dataset.from_list(data)\n",
"ds"
]
},
{
"cell_type": "code",
"execution_count": 121,
"id": "a17fa8c3-8cba-4b47-a9fe-00cc76fd97d9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'qid': [0, 1],\n",
" 'question': ['what is the meaning of life?', 'what is the meaning of life?'],\n",
" 'answers': [[{'id': 1, 'text': 'obviously 42'}],\n",
" [{'id': 2, 'text': 'obviously 42'}]]}"
]
},
"execution_count": 121,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds[:2]"
]
},
{
"cell_type": "code",
"execution_count": 122,
"id": "0df81ac7-b9ee-442c-a223-adf63d7adabd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6800000"
]
},
"execution_count": 122,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset_nbytes = ds._estimate_nbytes()\n",
"dataset_nbytes"
]
},
{
"cell_type": "code",
"execution_count": 123,
"id": "a96a3efc-13e7-4bb6-a3a1-5e2c03f27a1e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6800000"
]
},
"execution_count": 123,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds._estimate_nbytes()"
]
},
{
"cell_type": "code",
"execution_count": 145,
"id": "34902b53-b2c1-4b03-97e4-e8ec9b603daa",
"metadata": {},
"outputs": [],
"source": [
"def save_shards(dataset, path_to_repo, domain, shard_size=\"100MB\"):\n",
"    \"\"\"Write `dataset` as Parquet shards under `<path_to_repo>/data/<domain>`.\n",
"\n",
"    The number of shards is derived from the estimated dataset size divided\n",
"    by `shard_size`, then capped at the number of rows so that no shard is\n",
"    requested beyond the available rows. Files are named\n",
"    `train-XXXXX-of-YYYYY.parquet`; the target directory is created if it\n",
"    does not exist.\n",
"\n",
"    Args:\n",
"        dataset: Dataset to split; must support `len()`, `shard()` and\n",
"            `_estimate_nbytes()`.\n",
"        path_to_repo: Root directory of the target repository.\n",
"        domain: Name of the subdirectory of `data/` to write into.\n",
"        shard_size: Target size of one shard as a size string such as\n",
"            `100MB`, parsed by `convert_file_size_to_int`.\n",
"    \"\"\"\n",
"    path = Path(path_to_repo) / \"data\" / domain\n",
"    path.mkdir(parents=True, exist_ok=True)\n",
"    dataset_nbytes = dataset._estimate_nbytes()\n",
"    max_shard_size = convert_file_size_to_int(shard_size)\n",
"    # True division + int(): the size estimate is not guaranteed to be an int.\n",
"    num_shards = int(dataset_nbytes / max_shard_size) + 1\n",
"    # Never ask for more shards than rows (surplus shards would be empty),\n",
"    # but always write at least one file, even for an empty dataset.\n",
"    num_shards = max(1, min(num_shards, len(dataset)))\n",
"    print(f\"Saving the dataset with {num_shards=}\")\n",
"\n",
"    for shard_idx in range(num_shards):\n",
"        sharded_ds = dataset.shard(\n",
"            num_shards=num_shards, index=shard_idx, contiguous=True\n",
"        )\n",
"        shard_name = f\"train-{shard_idx:05d}-of-{num_shards:05d}.parquet\"\n",
"        sharded_ds.to_parquet(str(path / shard_name))"
]
},
{
"cell_type": "code",
"execution_count": 146,
"id": "1b8da35b-958a-4c28-a983-8d0376e0ca1d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saving the dataset with num_shards=2\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "eef94ca3f4134818a7694601047835a2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Creating parquet from Arrow format: 0%| | 0/50 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f4a035be666e47d49e0ff3170aa089fc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Creating parquet from Arrow format: 0%| | 0/50 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"save_shards(ds, \".\", \"ai\", shard_size=\"5MB\")"
]
},
{
"cell_type": "code",
"execution_count": 144,
"id": "c5f13bcc-3909-4500-bdd4-d846b3c3daa1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train-00000-of-00002.parquet train-00001-of-00002.parquet\n"
]
}
],
"source": [
"!ls train-*"
]
},
{
"cell_type": "code",
"execution_count": 125,
"id": "27093506-e562-44fd-80f6-fbf599d659f1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2"
]
},
"execution_count": 125,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"max_shard_size = convert_file_size_to_int(\"5MB\")\n",
"num_shards = int(dataset_nbytes / max_shard_size) + 1\n",
"num_shards = max(num_shards, 1)\n",
"num_shards"
]
},
{
"cell_type": "code",
"execution_count": 89,
"id": "4a0b384a-df56-4009-b137-fe1a80502ff8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'qid': [50000, 50001, 50002],\n",
" 'question': ['what is the meaning of life?',\n",
" 'what is the meaning of life?',\n",
" 'what is the meaning of life?'],\n",
" 'answers': [[{'id': 50001, 'text': 'obviously 42'}],\n",
" [{'id': 50002, 'text': 'obviously 42'}],\n",
" [{'id': 50003, 'text': 'obviously 42'}]]}"
]
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds.shard(num_shards=4, index=2, contiguous=True)[:3]"
]
},
{
"cell_type": "code",
"execution_count": 126,
"id": "1e6dab15-4c11-4121-b3dd-5548a38c74ab",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c8b340b54d2a4366ad31f8d8f06f91d9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Creating parquet from Arrow format: 0%| | 0/50 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d71fb52be1a143868b836914c3e7eb0e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Creating parquet from Arrow format: 0%| | 0/50 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for shard_idx in range(num_shards):\n",
" sharded_ds = ds.shard(num_shards=num_shards, index=shard_idx, contiguous=True)\n",
" sharded_ds.to_parquet(f\"train-{shard_idx:05d}-of-{num_shards:05d}.parquet\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1d662a7-bf3f-4b89-bf23-3e03bedba1ab",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "hf",
"language": "python",
"name": "hf"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment