Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@okapies
Last active February 12, 2019 09:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save okapies/fac7bf4a5971c092bccadb2face4284f to your computer and use it in GitHub Desktop.
Save okapies/fac7bf4a5971c092bccadb2face4284f to your computer and use it in GitHub Desktop.
Benchmark script for `ndarray.take` in ChainerX
import argparse
import timeit
# Command-line options controlling the benchmark run.
parser = argparse.ArgumentParser(
    description='Benchmark script for `ndarray.take` in ChainerX')
parser.add_argument('--number', type=int, default=20,
                    help='Number of times the timed statement is executed')
parser.add_argument('--device', type=str, default="native:0",
                    help='ChainerX device specifier (default: native:0)')
parser.add_argument('--batch-size', type=int, default=1,
                    help='Number of indices passed to each take() call')
# Required: when omitted the value is None, which cannot be mapped to any
# index container and previously only failed deep inside the timed code.
parser.add_argument('--indices', choices=['list', 'numpy', 'chainerx'],
                    required=True,
                    help='Container type used to hold the indices')
parser.add_argument('--data', default='mnist',
                    help='Path to the directory that contains MNIST dataset')
args = parser.parse_args()

# The batch size is used as a range() step: 0 raises ValueError there and a
# negative value yields an empty loop, so reject both up front.
if args.batch_size < 1:
    parser.error('--batch-size must be a positive integer')
# Setup code handed to timeit.  It loads the MNIST training images onto the
# requested device and defines every name used by the timed statement
# (X, N, batch_size, idx_type, all_indices_np, gen_indices, dev).
# The command-line values are inserted with !r so that they arrive as valid
# Python literals even when they contain quotes or backslashes.
setup = """
import gzip
import pathlib

import chainerx as chx
import numpy as np

try:
    import cupy
except ImportError:
    # CuPy is only needed to recognise CuPy arrays further down.
    cupy = None


def get_mnist(path, name, device):
    # Read the gzipped IDX image/label files `name`-images-idx3-ubyte.gz and
    # `name`-labels-idx1-ubyte.gz found in directory `path`.
    # Returns (images, labels): images as float32 rows of 784 values scaled
    # by 1/255, labels as int32.  With device=None both stay NumPy arrays,
    # otherwise they are copied into ChainerX arrays on `device`.
    path = pathlib.Path(path)
    x_path = path / (name + '-images-idx3-ubyte.gz')
    y_path = path / (name + '-labels-idx1-ubyte.gz')

    with gzip.open(x_path, 'rb') as fx:
        fx.read(16)  # skip header
        # read/frombuffer is used instead of fromfile because fromfile does
        # not handle gzip file correctly
        x = np.frombuffer(fx.read(), dtype=np.uint8).reshape(-1, 784)
    with gzip.open(y_path, 'rb') as fy:
        fy.read(8)  # skip header
        y = np.frombuffer(fy.read(), dtype=np.uint8)

    assert x.shape[0] == y.shape[0]

    x = x.astype(np.float32)
    x /= 255
    y = y.astype(np.int32)

    if device is None:
        # NOTE: return cupy.array(x), cupy.array(y) here to benchmark CuPy.
        return x, y
    return chx.array(x, device=device), chx.array(y, device=device)


idx_type = {idx_type!r}
device = chx.get_device({device!r})
path = {path!r}
batch_size = {batch_size}

print('Device: ' + str(device))
chx.set_default_device(device)

# get_mnist(path, 'train', None) would keep the data as NumPy arrays instead.
X, Y = get_mnist(path, 'train', device)
print('Data Type: ' + str(type(X)))
print('Index Type: ' + str(idx_type))

if cupy is None:
    device_array_types = (chx.ndarray,)
else:
    device_array_types = (cupy.ndarray, chx.ndarray)

if isinstance(X, device_array_types):
    dev = X.device
else:
    # Other array types get a stand-in whose synchronize() does nothing, so
    # the timed statement can call dev.synchronize() unconditionally.
    class DummyDevice():
        def synchronize(self):
            pass
    dev = DummyDevice()

N = X.shape[0]
all_indices_np = np.arange(N, dtype=np.int64)
np.random.seed(seed=0)


def gen_indices(all_indices_np, idx_type):
    # Shuffle `all_indices_np` in place and return it in the container
    # selected by `idx_type` ('list', 'numpy' or 'chainerx').
    np.random.shuffle(all_indices_np)
    if idx_type == 'list':
        return all_indices_np.tolist()
    elif idx_type == 'numpy':
        return all_indices_np
    elif idx_type == 'chainerx':
        return chx.array(all_indices_np, device=device)
    # Fail loudly instead of returning None for an unsupported type.
    raise ValueError('unknown index type: ' + repr(idx_type))
""".format(idx_type=args.indices, device=args.device, path=args.data,
           batch_size=args.batch_size)
# Statement timed by timeit: reshuffle the indices, then call X.take() once
# per mini-batch of `batch_size` indices.  Every name used here is defined by
# the setup code.
# NOTE(review): the loop body's indentation was missing; dev.synchronize() is
# placed after the loop (one synchronization per run) - confirm this matches
# the intended measurement.
stmt = """
all_indices = gen_indices(all_indices_np, idx_type)
for i in range(0, N, batch_size):
    indices = all_indices[i:i + batch_size]
    # X[indices] is the indexing-based alternative to take().
    X.take(indices, axis=0)
dev.synchronize()
"""
# Execute the setup once, run the statement args.number times, and print the
# total time reported by timeit.
timer = timeit.Timer(stmt, setup=setup)
print(timer.timeit(number=args.number))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment