Skip to content

Instantly share code, notes, and snippets.

@clane9
Created April 30, 2024 02:28
Show Gist options
  • Save clane9/6f12d2372ba00fb01adda1074e8c5a45 to your computer and use it in GitHub Desktop.
Save clane9/6f12d2372ba00fb01adda1074e8c5a45 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import io\n",
"import psutil\n",
"from pathlib import Path\n",
"\n",
"import numpy as np\n",
"import webdataset as wds\n",
"from torch.utils.data import DataLoader"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def get_memory():\n",
" proc = psutil.Process()\n",
" children = proc.children(recursive=True)\n",
" mem = []\n",
" for p in [proc] + children:\n",
" mem.append(p.memory_info().rss / 1024**2)\n",
" return mem"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generate a dummy dataset consisting of sequences of high-dimensional data."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# change directory as needed\n",
"root = Path(\"/local/slurm-23665773/local/data\")\n",
"# parents=True so the cell also works when intermediate directories are missing\n",
"root.mkdir(parents=True, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def encode_numpy(data: np.ndarray) -> bytes:\n",
" with io.BytesIO() as f:\n",
" np.save(f, data)\n",
" buf = f.getvalue()\n",
" return buf"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"# writing /local/slurm-23665773/local/data/000000.tar 0 0.0 GB 0\n",
"# writing /local/slurm-23665773/local/data/000001.tar 50 0.4 GB 50\n",
"# writing /local/slurm-23665773/local/data/000002.tar 50 0.4 GB 100\n",
"# writing /local/slurm-23665773/local/data/000003.tar 50 0.4 GB 150\n",
"# writing /local/slurm-23665773/local/data/000004.tar 50 0.4 GB 200\n",
"# writing /local/slurm-23665773/local/data/000005.tar 50 0.4 GB 250\n",
"# writing /local/slurm-23665773/local/data/000006.tar 50 0.4 GB 300\n",
"# writing /local/slurm-23665773/local/data/000007.tar 50 0.4 GB 350\n",
"# writing /local/slurm-23665773/local/data/000008.tar 50 0.4 GB 400\n",
"# writing /local/slurm-23665773/local/data/000009.tar 50 0.4 GB 450\n",
"# writing /local/slurm-23665773/local/data/000010.tar 50 0.4 GB 500\n",
"# writing /local/slurm-23665773/local/data/000011.tar 50 0.4 GB 550\n",
"# writing /local/slurm-23665773/local/data/000012.tar 50 0.4 GB 600\n",
"# writing /local/slurm-23665773/local/data/000013.tar 50 0.4 GB 650\n",
"# writing /local/slurm-23665773/local/data/000014.tar 50 0.4 GB 700\n",
"# writing /local/slurm-23665773/local/data/000015.tar 50 0.4 GB 750\n",
"# writing /local/slurm-23665773/local/data/000016.tar 50 0.4 GB 800\n",
"# writing /local/slurm-23665773/local/data/000017.tar 50 0.4 GB 850\n",
"# writing /local/slurm-23665773/local/data/000018.tar 50 0.4 GB 900\n",
"# writing /local/slurm-23665773/local/data/000019.tar 50 0.4 GB 950\n"
]
}
],
"source": [
"with wds.ShardWriter(str(root / \"%06d.tar\"), maxsize=400*1024*1024, encoder=False) as sink:\n",
" for ii in range(1000):\n",
" x = np.random.randint(0, 255, (256, 32768), dtype=np.uint8)\n",
" buf = encode_numpy(x)\n",
" sample = {\"__key__\": f\"{ii:05d}\", \"npy\": buf}\n",
" sink.write(sample)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Test 1: iterating over full sequences, shuffled and batched.\n",
"\n",
"Here the memory usage is roughly the buffer size plus one shard."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"dataset = (\n",
" wds.WebDataset(str(root / \"{000000..000019}.tar\"))\n",
" .decode()\n",
" .to_tuple(\"npy\")\n",
" .shuffle(200)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(256, 32768)\n"
]
}
],
"source": [
"x, = next(iter(dataset))\n",
"print(x.shape)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1600.0\n"
]
}
],
"source": [
"buffer_size = 200 * 256 * 32768 / 1024 ** 2\n",
"print(buffer_size)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"loader = DataLoader(dataset.batched(8), num_workers=2, batch_size=None)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0] Shape: torch.Size([8, 256, 32768]) Mem: (970,1301,1373)\n",
"[ 10] Shape: torch.Size([8, 256, 32768]) Mem: (970,1565,1691)\n",
"[ 20] Shape: torch.Size([8, 256, 32768]) Mem: (970,1893,1976)\n",
"[ 30] Shape: torch.Size([8, 256, 32768]) Mem: (970,2053,2062)\n",
"[ 40] Shape: torch.Size([8, 256, 32768]) Mem: (970,2078,2055)\n",
"[ 50] Shape: torch.Size([8, 256, 32768]) Mem: (970,2094,2051)\n",
"[ 60] Shape: torch.Size([8, 256, 32768]) Mem: (970,2097,2055)\n",
"[ 70] Shape: torch.Size([8, 256, 32768]) Mem: (970,2113,2065)\n",
"[ 80] Shape: torch.Size([8, 256, 32768]) Mem: (970,2093,2077)\n",
"[ 90] Shape: torch.Size([8, 256, 32768]) Mem: (970,2079,2077)\n",
"[ 100] Shape: torch.Size([8, 256, 32768]) Mem: (970,2082,2097)\n",
"[ 110] Shape: torch.Size([8, 256, 32768]) Mem: (970,2002,2020)\n",
"[ 120] Shape: torch.Size([8, 256, 32768]) Mem: (970,1999,2017)\n"
]
}
],
"source": [
"for ii, (x,) in enumerate(loader):\n",
" if ii % 10 == 0:\n",
" mem = get_memory()\n",
" mem_fmt = \",\".join(f\"{v:.0f}\" for v in mem)\n",
" print(f\"[{ii:>6d}] Shape: {x.shape} Mem: ({mem_fmt})\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Test 2: iterating over short clips sampled sequentially from each sequence, then shuffled. The shuffle buffer holds more samples, but the same total size in bytes.\n",
"\n",
"Now, the memory usage of the data loader is much higher! Why??"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def to_clips(window: int = 16):\n",
" def _filter(source):\n",
" for x, in source:\n",
" for start in range(0, len(x) - window, window):\n",
" yield (x[start: start+window],)\n",
"\n",
" return _filter"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"dataset2 = (\n",
" wds.WebDataset(str(root / \"{000000..000019}.tar\"))\n",
" .decode()\n",
" .to_tuple(\"npy\")\n",
" .compose(to_clips(window=16))\n",
" .shuffle(3200)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(16, 32768)\n"
]
}
],
"source": [
"x, = next(iter(dataset2))\n",
"print(x.shape)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1600.0\n"
]
}
],
"source": [
"buffer_size = 3200 * 16 * 32768 / 1024 ** 2\n",
"print(buffer_size)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"loader2 = DataLoader(dataset2.batched(128), num_workers=2, batch_size=None)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0] Shape: torch.Size([128, 16, 32768]) Mem: (970,933,893)\n",
"[ 10] Shape: torch.Size([128, 16, 32768]) Mem: (970,1261,1279)\n",
"[ 20] Shape: torch.Size([128, 16, 32768]) Mem: (970,1901,1917)\n",
"[ 30] Shape: torch.Size([128, 16, 32768]) Mem: (970,2541,2566)\n",
"[ 40] Shape: torch.Size([128, 16, 32768]) Mem: (970,3174,3247)\n",
"[ 50] Shape: torch.Size([128, 16, 32768]) Mem: (933,3660,3819)\n",
"[ 60] Shape: torch.Size([128, 16, 32768]) Mem: (933,3962,4089)\n",
"[ 70] Shape: torch.Size([128, 16, 32768]) Mem: (933,4237,4156)\n",
"[ 80] Shape: torch.Size([128, 16, 32768]) Mem: (933,4130,4208)\n",
"[ 90] Shape: torch.Size([128, 16, 32768]) Mem: (933,4138,4190)\n",
"[ 100] Shape: torch.Size([128, 16, 32768]) Mem: (933,4142,4188)\n",
"[ 110] Shape: torch.Size([128, 16, 32768]) Mem: (933,4098,4124)\n"
]
}
],
"source": [
"for ii, (x,) in enumerate(loader2):\n",
" if ii % 10 == 0:\n",
" mem = get_memory()\n",
" mem_fmt = \",\".join(f\"{v:.0f}\" for v in mem)\n",
" print(f\"[{ii:>6d}] Shape: {x.shape} Mem: ({mem_fmt})\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment