Skip to content

Instantly share code, notes, and snippets.

@gsakkis
Created February 19, 2023 12:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gsakkis/cc814007dbfb555cbfe237efc570c6e6 to your computer and use it in GitHub Desktop.
import argparse
import multiprocessing
import os
from timeit import default_timer
from typing import Optional, Sequence, Type, Union
import numpy as np
import tiledb
from tiledb.ml.readers.pytorch import PyTorchTileDBDataLoader
from tiledb.ml.readers.tensorflow import TensorflowTileDBDataset
def rand_mask(shape: Sequence[int], sparsity: float) -> np.ndarray:
    """Return a boolean mask of the given shape with the requested sparsity.

    Args:
        shape: Dimension sizes of the mask.
        sparsity: Fraction of cells that must be False, in the [0, 1) range.

    Returns:
        A ``np.bool_`` array with exactly ``int(size * (1 - sparsity))``
        cells set to True, at uniformly random positions.
    """
    rnd = np.random.default_rng()
    mask = np.zeros(shape, dtype=np.bool_)
    non_zero_size = int(mask.size * (1 - sparsity))
    # Sample WITHOUT replacement: the previous default (with replacement)
    # produced duplicate indices, so fewer than non_zero_size cells were set
    # and the actual sparsity was higher than requested.
    non_zero_idxs = rnd.choice(mask.size, size=non_zero_size, replace=False)
    mask.flat[non_zero_idxs] = True
    return mask
def seq_array(shape: Sequence[int], dtype: np.dtype) -> np.ndarray:
    """Return an array of *shape* where every cell holds its first-axis index.

    Row ``i`` (along axis 0) is filled entirely with the value ``i``, then the
    whole array is cast to *dtype*.
    """
    reps = int(np.prod(shape[1:]))
    flat = np.repeat(np.arange(shape[0]), reps)
    return flat.reshape(shape).astype(dtype)
def create_array(
    folder: str,
    shape: Sequence[int],
    tiles: Sequence[int],
    attr_dtypes: Sequence[np.dtype],
    sparsity: float,
) -> None:
    """Create a dense or sparse TileDB array under *folder* filled with sequential data.

    The array URI encodes sparsity, shape, tile extents and attribute dtypes;
    if an array already exists at that URI, nothing is done. Each attribute
    ``a{i}`` is populated from ``seq_array``; for sparse arrays only the cells
    selected by ``rand_mask`` are written.
    """
    assert len(shape) == len(tiles)
    assert all(dim >= extent for dim, extent in zip(shape, tiles))

    is_sparse = sparsity > 0
    name_parts = [f"sparse-{sparsity}" if is_sparse else "dense"]
    name_parts.append("shape_" + "_".join(map(str, shape)))
    name_parts.append("tiles_" + "_".join(map(str, tiles)))
    name_parts.append("attrs_" + "_".join(map(str, attr_dtypes)))
    uri = os.path.join(folder, "-".join(name_parts))
    print(uri)
    if os.path.exists(uri):
        return

    dims = [
        tiledb.Dim(name=f"d{i}", domain=(0, dim - 1), tile=extent, dtype=np.int32)
        for i, (dim, extent) in enumerate(zip(shape, tiles))
    ]
    attrs = [
        tiledb.Attr(name=f"a{i}", dtype=dtype) for i, dtype in enumerate(attr_dtypes)
    ]
    schema = tiledb.ArraySchema(
        sparse=is_sparse, domain=tiledb.Domain(*dims), attrs=attrs
    )
    tiledb.Array.create(uri, schema)

    with tiledb.open(uri, "w") as tiledb_array:
        attr_data = [seq_array(shape, dtype) for dtype in attr_dtypes]
        # Sparse arrays are written only at the cells picked by the mask;
        # dense arrays get a full-domain write.
        idx = np.nonzero(rand_mask(shape, sparsity)) if is_sparse else slice(None)
        tiledb_array[idx] = {f"a{i}": data[idx] for i, data in enumerate(attr_data)}
def read_dataset(
    cls: Type[Union[TensorflowTileDBDataset, PyTorchTileDBDataLoader]],
    x_uri: str,
    y_uri: str,
    batch_size: int,
    buffer_bytes: int,
    shuffle_buffer_size: int,
    prefetch: Optional[int] = None,
    num_workers: Optional[int] = None,
    sparse_layout: Optional[str] = None,
    x_key_dim: Optional[str] = None,
    y_key_dim: Optional[str] = None,
    config: Optional[tiledb.Config] = None,
) -> None:
    """Open the X/Y TileDB arrays, build a *cls* loader over them, and time one
    full pass, printing the batch count and the elapsed seconds.
    """
    loader_kwargs = {
        "buffer_bytes": buffer_bytes,
        "batch_size": batch_size,
        "shuffle_buffer_size": shuffle_buffer_size,
        "x_key_dim": x_key_dim,
        "y_key_dim": y_key_dim,
    }
    # Optional knobs are passed through only when explicitly requested, since
    # not every loader class accepts them.
    if prefetch is not None:
        loader_kwargs["prefetch"] = prefetch
    if num_workers is not None:
        loader_kwargs["num_workers"] = num_workers
    if sparse_layout is not None:
        loader_kwargs["csr"] = sparse_layout == "csr"

    x_array = tiledb.open(x_uri, config=config)
    y_array = tiledb.open(y_uri, config=config)
    with x_array, y_array:
        loader = cls(x_array, y_array, **loader_kwargs)
        for _ in range(1):
            start = default_timer()
            print(sum(1 for _ in loader))
            print(f"Elapsed time={default_timer() - start:.2f}s")
def _add_create_parser(subparsers: "argparse._SubParsersAction") -> None:
    """Register the ``create`` subcommand and its options."""
    parser_create = subparsers.add_parser(
        "create",
        help="create array",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser_create.add_argument("dir", help="parent directory to store the array")
    parser_create.add_argument(
        "--sparsity",
        type=float,
        default=0.0,
        help="Array sparsity in the [0, 1) range. If non zero, the array will "
        "be sparse with the given sparsity ratio.",
    )
    parser_create.add_argument(
        "--shape",
        type=lambda s: tuple(map(int, s.split(","))),
        default=(100_000, 100, 10),
        help="comma separated ints of the array dimension sizes",
    )
    parser_create.add_argument(
        "--tiles",
        type=lambda s: tuple(map(int, s.split(","))),
        default=(2**10, 2**5, 2**2),
        help="comma separated ints of the array tile extents",
    )
    parser_create.add_argument(
        "--dtypes",
        type=lambda s: tuple(map(np.dtype, s.split(","))),
        default=(np.dtype("uint8"), np.dtype("float32")),
        help="comma separated strings of the array attribute dtypes",
    )


def _add_read_parser(subparsers: "argparse._SubParsersAction") -> None:
    """Register the ``read`` subcommand and its options."""
    parser_read = subparsers.add_parser(
        "read",
        help="read array",
        # consistent with the "create" subparser: show defaults in --help
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser_read.add_argument("x_uri", help="Training data TileDB URI")
    parser_read.add_argument("y_uri", help="Labels TileDB URI")
    loader_type = parser_read.add_mutually_exclusive_group(required=True)
    loader_type.add_argument(
        "--tensorflow",
        dest="type",
        action="store_const",
        const=TensorflowTileDBDataset,
        help="Use Tensorflow loader",
    )
    loader_type.add_argument(
        "--pytorch",
        dest="type",
        action="store_const",
        const=PyTorchTileDBDataLoader,
        help="Use PyTorch loader",
    )
    parser_read.add_argument(
        "-b",
        "--batch_size",
        type=int,
        default=32,
        help="Size of each batch",
    )
    parser_read.add_argument(
        "-B",
        "--buffer_bytes",
        type=int,
        help="Maximum size (in bytes) of memory to allocate for reading from each array",
    )
    parser_read.add_argument(
        "-s",
        "--shuffle_buffer_size",
        type=int,
        default=0,
        help="Shuffling buffer size (or 0 for no shuffling)",
    )
    parser_read.add_argument(
        "-w",
        "--num_workers",
        type=int,
        help="Number of workers to use for data loading (PyTorch only)",
    )
    parser_read.add_argument(
        "-p",
        "--prefetch",
        type=int,
        help="Number of batches to prefetch",
    )
    parser_read.add_argument(
        "--sparse_layout",
        choices=("csr", "coo"),
        # fixed doubled word ("for for") in the help text
        help="Sparse layout for 2d sparse arrays (Pytorch only)",
    )
    parser_read.add_argument(
        "--x_key_dim",
        help="X key dimension",
    )
    parser_read.add_argument(
        "--y_key_dim",
        help="Y key dimension",
    )
    parser_read.add_argument(
        "--memory_budget", type=int, help="sm.memory_budget config value"
    )
    parser_read.add_argument(
        "--max_incomplete_retries",
        type=int,
        help="py.max_incomplete_retries config value",
    )
    parser_read.add_argument(
        "--init_buffer_bytes",
        type=int,
        help="py.init_buffer_bytes config value",
    )
    # NOTE(review): --stats is parsed but never acted upon anywhere in this
    # script — either wire it up to tiledb stats dumping or drop it.
    parser_read.add_argument("--stats", action="store_true", help="dump TileDB stats")


def main() -> None:
    """CLI entry point: create a benchmark TileDB array or benchmark reading it
    through a Tensorflow or PyTorch TileDB-ML loader.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark reading row slices from a TileDB array",
    )
    subparsers = parser.add_subparsers(help="command", dest="cmd")
    _add_create_parser(subparsers)
    _add_read_parser(subparsers)
    opts = parser.parse_args()

    if opts.cmd == "create":
        create_array(
            folder=opts.dir,
            sparsity=opts.sparsity,
            shape=opts.shape,
            tiles=opts.tiles,
            attr_dtypes=opts.dtypes,
        )
    elif opts.cmd == "read":
        config = tiledb.Config()
        if opts.memory_budget is not None:
            config["sm.memory_budget"] = opts.memory_budget
        if opts.max_incomplete_retries is not None:
            config["py.max_incomplete_retries"] = opts.max_incomplete_retries
        if opts.init_buffer_bytes is not None:
            config["py.init_buffer_bytes"] = opts.init_buffer_bytes
        if opts.num_workers:
            # Forked PyTorch workers must not inherit TileDB process state;
            # a fork-server gives each worker a fresh interpreter.
            multiprocessing.set_start_method("forkserver")
        read_dataset(
            opts.type,
            opts.x_uri,
            opts.y_uri,
            opts.batch_size,
            opts.buffer_bytes,
            opts.shuffle_buffer_size,
            opts.prefetch,
            opts.num_workers,
            opts.sparse_layout,
            opts.x_key_dim,
            opts.y_key_dim,
            config,
        )
    else:
        # No subcommand given: show usage instead of silently doing nothing.
        parser.print_help()
# Script entry point: run the CLI when executed directly (not when imported).
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment