Skip to content

Instantly share code, notes, and snippets.

@okapies okapies/
Last active Feb 12, 2019

What would you like to do?
Benchmark script for `ndarray.take` in ChainerX
"""Benchmark ``ndarray.take`` in ChainerX on the MNIST training images.

The training images are loaded onto a ChainerX device and then gathered
batch by batch with ``x.take(indices, axis=0)``, where the indices are held
in a Python list, a NumPy array, or a ChainerX array (``--indices``). The
total wall-clock time, in seconds, of ``--number`` full passes over the data
set is printed.

Usage::

    python this_script.py --indices numpy [--device native:0]
        [--batch-size 1] [--number 20] [--data mnist]
"""
import argparse
import gzip
import pathlib
import timeit

import numpy as np

INDEX_TYPES = ('list', 'numpy', 'chainerx')

# Fixed header sizes of the MNIST IDX files: the image file stores a magic
# number, item count, row count and column count (4 x 4 bytes); the label
# file stores only the magic number and item count (2 x 4 bytes).
_IMAGE_HEADER_BYTES = 16
_LABEL_HEADER_BYTES = 8
_PIXELS_PER_IMAGE = 28 * 28


def _positive_int(text):
    """Parse *text* as an integer >= 1 (used as an argparse ``type=``)."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(
            'expected a positive integer, got {!r}'.format(text))
    return value


def parse_args(argv=None):
    """Parse the benchmark options; *argv* defaults to ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser(
        description='Benchmark ndarray.take in ChainerX.')
    parser.add_argument('--number', type=_positive_int, default=20)
    parser.add_argument('--device', type=str, default="native:0")
    # A non-positive batch size would either crash range() or time nothing.
    parser.add_argument('--batch-size', type=_positive_int, default=1)
    # Without a value the benchmark used to crash inside timeit with a None
    # index, so require it explicitly.
    parser.add_argument('--indices', choices=INDEX_TYPES, required=True)
    parser.add_argument('--data', default='mnist', help='Path to the directory that contains MNIST dataset')
    return parser.parse_args(argv)


def get_mnist(path, name, device):
    """Load one split of MNIST from gzipped IDX files.

    Args:
        path: Directory containing ``<name>-images-idx3-ubyte.gz`` and
            ``<name>-labels-idx1-ubyte.gz``.
        name: File-name prefix of the split, e.g. ``'train'`` or ``'t10k'``.
        device: ChainerX device to place the arrays on, or ``None`` to
            return NumPy arrays.

    Returns:
        ``(x, y)``: ``x`` is float32 with shape ``(n, 784)`` scaled from
        ``[0, 255]`` to ``[0, 1]``; ``y`` is int32 with shape ``(n,)``.

    Raises:
        ValueError: if the image and label files hold different counts.
    """
    path = pathlib.Path(path)
    x_path = path / (name + '-images-idx3-ubyte.gz')
    y_path = path / (name + '-labels-idx1-ubyte.gz')
    # read()/frombuffer is used instead of fromfile because fromfile does not
    # handle gzip files correctly.
    with gzip.open(x_path, 'rb') as fx:
        fx.read(_IMAGE_HEADER_BYTES)  # skip header
        x = np.frombuffer(fx.read(), dtype=np.uint8).reshape(
            -1, _PIXELS_PER_IMAGE)
    with gzip.open(y_path, 'rb') as fy:
        fy.read(_LABEL_HEADER_BYTES)  # skip header
        y = np.frombuffer(fy.read(), dtype=np.uint8)
    # A real check rather than `assert`, which is stripped under `python -O`.
    if x.shape[0] != y.shape[0]:
        raise ValueError(
            '{} images but {} labels in {}'.format(
                x.shape[0], y.shape[0], path))
    x = x.astype(np.float32)
    x /= 255
    y = y.astype(np.int32)
    if device is None:
        return x, y
    # Imported lazily so the NumPy-only path works without ChainerX.
    import chainerx as chx
    return chx.array(x, device=device), chx.array(y, device=device)


def gen_indices(all_indices_np, idx_type, device=None):
    """Convert an index array into the container type being benchmarked.

    Args:
        all_indices_np: NumPy integer index array.
        idx_type: One of ``INDEX_TYPES``.
        device: ChainerX device for the ``'chainerx'`` type; unused otherwise.

    Returns:
        A Python list, the NumPy array itself, or a ChainerX array.

    Raises:
        ValueError: if *idx_type* is not one of ``INDEX_TYPES``.
    """
    if idx_type == 'list':
        return all_indices_np.tolist()
    if idx_type == 'numpy':
        return all_indices_np
    if idx_type == 'chainerx':
        import chainerx as chx  # lazy: only this branch needs ChainerX
        return chx.array(all_indices_np, device=device)
    raise ValueError('unknown index type: {!r}'.format(idx_type))


def _synchronizer(x):
    """Return a callable that waits for pending work on *x*'s device.

    ChainerX/CuPy arrays expose ``x.device.synchronize()``; for anything else
    (e.g. NumPy arrays) a no-op is returned, replacing the former
    ``DummyDevice`` and the hard CuPy dependency.
    """
    sync = getattr(getattr(x, 'device', None), 'synchronize', None)
    return sync if callable(sync) else (lambda: None)


def main(argv=None):
    """Run the benchmark and print the total elapsed time in seconds."""
    args = parse_args(argv)
    import chainerx as chx

    device = chx.get_device(args.device)
    print('Device: ' + str(device))
    x, _ = get_mnist(args.data, 'train', device)
    print('Data Type: ' + str(type(x)))
    print('Index Type: ' + args.indices)

    synchronize = _synchronizer(x)
    n = x.shape[0]
    all_indices_np = np.arange(n, dtype=np.int64)
    batch_size = args.batch_size
    idx_type = args.indices

    def run():
        # Index conversion stays inside the timed region, as in the original
        # stmt, so each container's conversion cost is measured too.
        all_indices = gen_indices(all_indices_np, idx_type, device)
        for i in range(0, n, batch_size):
            x.take(all_indices[i:i + batch_size], axis=0)
        # Device work may be asynchronous; wait for it so the timing covers
        # the gathers themselves and not only their launches.
        synchronize()

    # Passing a callable avoids formatting values into setup source code
    # (a data path containing a quote used to break the generated setup).
    print(timeit.timeit(run, number=args.number))


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.