Last active
January 17, 2019 11:28
-
-
Save okapies/303c6316ae11987c17eb7a512dfb067c to your computer and use it in GitHub Desktop.
Benchmarking for `_extract_apply_in_data` in Chainer's `function_node`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Micro-benchmark for ``_extract_apply_in_data`` in Chainer's ``function_node``.

The script builds a list ``vs`` of ChainerX arrays (optionally wrapped in
``chainer.Variable``) inside a ``timeit`` setup snippet and then times
``_extract_apply_in_data(vs)``.
"""
import argparse
import timeit

# Statement that is timed; ``vs`` is created by the rendered setup code.
STMT = "_extract_apply_in_data(vs)"

# Source executed once by ``timeit`` before the timed loop.  Placeholders are
# filled in by ``build_setup``.  ``{device!r}`` inserts the device name as a
# quoted Python literal so that names containing quotes cannot break (or
# inject code into) the generated source.
SETUP_TEMPLATE = """
from chainer import Variable
from chainer.function_node import _extract_apply_in_data
import chainerx as chx
device = chx.get_device({device!r})
nargs = {nargs}
wrap_variable = {wrap_variable}
batch_size = {batch_size}
vs = []
print('Device: ' + str(device))
chx.set_default_device(device)
for i in range(nargs):
    if wrap_variable:
        vs.append(Variable(chx.array(range(i * batch_size, (i + 1) * batch_size), dtype=chx.float32)))
    else:
        vs.append(chx.array(range(i * batch_size, (i + 1) * batch_size), dtype=chx.float32))
print("vs: " + str(vs))
"""


def parse_args(argv=None):
    """Parse the benchmark's command-line options.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``number``, ``device``, ``nargs``,
        ``wrap_variable`` and ``batch_size``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--number', type=int, default=10000000)
    parser.add_argument('--device', type=str, default="native:0")
    parser.add_argument('--nargs', type=int, default=2)
    parser.add_argument('--wrap-variable', action='store_true')
    parser.add_argument('--batch-size', type=int, default=1)
    return parser.parse_args(argv)


def build_setup(device, nargs, wrap_variable, batch_size):
    """Render the ``timeit`` setup source for the given configuration.

    Args:
        device: Device name passed to ``chx.get_device`` in the setup code.
        nargs: Number of arrays appended to ``vs``.
        wrap_variable: If true, each array is wrapped in ``Variable``.
        batch_size: Number of elements in each array.

    Returns:
        Python source code as a string.
    """
    return SETUP_TEMPLATE.format(
        device=str(device),
        nargs=int(nargs),
        # Coerce to a real bool so the rendered literal is ``True``/``False``.
        wrap_variable=bool(wrap_variable),
        batch_size=int(batch_size),
    )


def main(argv=None):
    """Run the benchmark and print the total elapsed time in seconds."""
    args = parse_args(argv)
    setup = build_setup(
        args.device, args.nargs, args.wrap_variable, args.batch_size)
    print(timeit.timeit(STMT, setup=setup, number=args.number))


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment