Last active
February 12, 2019 09:59
-
-
Save okapies/fac7bf4a5971c092bccadb2face4284f to your computer and use it in GitHub Desktop.
Benchmark script for `ndarray.take` in ChainerX
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Benchmark ``ndarray.take`` in ChainerX on the MNIST training images.

The benchmark body is executed through :mod:`timeit`: ``SETUP_TEMPLATE`` is
rendered with the command-line options and run once per timing, ``STMT`` is
the timed statement and is run ``--number`` times.
"""
import argparse
import timeit


# Setup code handed to timeit. The four fields are filled in by build_setup();
# string values are inserted with ``!r`` so that they always form valid Python
# literals, whatever characters they contain.
SETUP_TEMPLATE = """
import gzip
import pathlib

import chainerx as chx
import numpy as np

try:
    import cupy
except ImportError:
    # cupy is only used below to recognise cupy arrays, so it is optional.
    cupy = None


def get_mnist(path, name, device):
    # Load the gzip-compressed MNIST image/label pair called `name` from the
    # directory `path`. Returns NumPy arrays when `device` is None, otherwise
    # ChainerX arrays placed on `device`.
    path = pathlib.Path(path)
    x_path = path / (name + '-images-idx3-ubyte.gz')
    y_path = path / (name + '-labels-idx1-ubyte.gz')

    with gzip.open(x_path, 'rb') as fx:
        fx.read(16)  # skip header
        # read/frombuffer is used instead of fromfile because fromfile does not
        # handle gzip file correctly
        x = np.frombuffer(fx.read(), dtype=np.uint8).reshape(-1, 784)

    with gzip.open(y_path, 'rb') as fy:
        fy.read(8)  # skip header
        y = np.frombuffer(fy.read(), dtype=np.uint8)

    assert x.shape[0] == y.shape[0]

    x = x.astype(np.float32)
    x /= 255
    y = y.astype(np.int32)

    if device is None:
        return x, y
    return chx.array(x, device=device), chx.array(y, device=device)


idx_type = {idx_type!r}
device = chx.get_device({device!r})
path = {path!r}
batch_size = {batch_size!r}

print('Device: ' + str(device))
chx.set_default_device(device)

# Alternative: X, Y = get_mnist(path, 'train', None) keeps the data in NumPy.
X, Y = get_mnist(path, 'train', device)

print('Data Type: ' + str(type(X)))
print('Index Type: ' + idx_type)

array_types = (chx.ndarray,) if cupy is None else (cupy.ndarray, chx.ndarray)
if isinstance(X, array_types):
    dev = X.device
else:
    # Plain NumPy data has no device; give the timed statement a no-op
    # object so it can call dev.synchronize() unconditionally.
    class DummyDevice():
        def synchronize(self):
            pass
    dev = DummyDevice()

N = X.shape[0]
all_indices_np = np.arange(N, dtype=np.int64)
np.random.seed(seed=0)


def gen_indices(all_indices_np, idx_type):
    # Shuffle `all_indices_np` in place and return it as the requested
    # container type.
    np.random.shuffle(all_indices_np)
    if idx_type == 'list':
        return all_indices_np.tolist()
    elif idx_type == 'numpy':
        return all_indices_np
    elif idx_type == 'chainerx':
        return chx.array(all_indices_np, device=device)
    raise ValueError('unknown index type: ' + repr(idx_type))
"""

# Timed statement. Note that generating (shuffling and converting) the indices
# is part of what is measured, not only the take() calls.
STMT = """
all_indices = gen_indices(all_indices_np, idx_type)
for i in range(0, N, batch_size):
    indices = all_indices[i:i + batch_size]
    X.take(indices, axis=0)
    # Alternative to benchmark: X[indices]
dev.synchronize()
"""


def positive_int(text):
    """Parse *text* as an integer and reject values below 1.

    Used as an argparse ``type``: a zero batch size would make ``range()``
    raise inside the timed statement, and a negative one would silently
    produce an empty loop and a meaningless timing.
    """
    value = int(text)
    if value <= 0:
        raise argparse.ArgumentTypeError(
            'must be a positive integer, got {!r}'.format(text))
    return value


def parse_args(argv=None):
    """Parse the benchmark's command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated :class:`argparse.Namespace`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--number', type=int, default=20)
    parser.add_argument('--device', type=str, default="native:0")
    parser.add_argument('--batch-size', type=positive_int, default=1)
    # Required: without a value the benchmark has no index container to build
    # and would only fail after the whole dataset had been loaded.
    parser.add_argument('--indices', choices=['list', 'numpy', 'chainerx'],
                        required=True)
    parser.add_argument('--data', default='mnist', help='Path to the directory that contains MNIST dataset')
    return parser.parse_args(argv)


def build_setup(idx_type, device, path, batch_size):
    """Render ``SETUP_TEMPLATE`` into the setup source passed to timeit.

    Args:
        idx_type: Index container type ('list', 'numpy' or 'chainerx').
        device: ChainerX device specifier, e.g. ``'native:0'``.
        path: Directory containing the MNIST ``.gz`` files.
        batch_size: Number of indices passed to each ``take`` call.

    Returns:
        Python source code as a string.
    """
    return SETUP_TEMPLATE.format(
        idx_type=idx_type, device=device, path=path, batch_size=batch_size)


def main():
    """Run the benchmark and print the total time reported by timeit."""
    args = parse_args()
    setup = build_setup(args.indices, args.device, args.data, args.batch_size)
    print(timeit.timeit(STMT, setup=setup, number=args.number))


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment