Skip to content

Instantly share code, notes, and snippets.

@okapies
Last active January 17, 2019 11:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save okapies/303c6316ae11987c17eb7a512dfb067c to your computer and use it in GitHub Desktop.
Save okapies/303c6316ae11987c17eb7a512dfb067c to your computer and use it in GitHub Desktop.
Benchmarking for `_extract_apply_in_data` in Chainer's `function_node`
import argparse
import timeit
# Command-line options for the benchmark.  Each parsed value is copied into a
# module-level name below so the rest of the script can use it directly.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--number', type=int, default=10000000)
parser.add_argument(
    '--device', default='native:0')
parser.add_argument(
    '--nargs', type=int, default=2)
parser.add_argument(
    '--wrap-variable', action='store_true')
parser.add_argument(
    '--batch-size', type=int, default=1)
args = parser.parse_args()

# argparse exposes '--wrap-variable' / '--batch-size' with underscores.
number, device, nargs = args.number, args.device, args.nargs
wrap_variable, batch_size = args.wrap_variable, args.batch_size
# Source text that timeit executes once before the timed statement.  It
# selects the ChainerX device and builds `vs`, a list of `nargs` inputs: each
# is a float32 ChainerX array holding `batch_size` consecutive integers,
# wrapped in a chainer.Variable when `wrap_variable` is true.
#
# The text is compiled as ordinary Python source, so the bodies of the `for`
# and of the conditional must be indented; without indentation the setup
# fails with IndentationError before anything is timed.
# `{device!r}` embeds the device name as a correctly quoted string literal.
setup = """
from chainer import Variable
from chainer.function_node import _extract_apply_in_data
import chainerx as chx
device = chx.get_device({device!r})
nargs = {nargs}
wrap_variable = {wrap_variable}
batch_size = {batch_size}
vs = []
print('Device: ' + str(device))
chx.set_default_device(device)
for i in range(nargs):
    x = chx.array(range(i * batch_size, (i + 1) * batch_size), dtype=chx.float32)
    if wrap_variable:
        vs.append(Variable(x))
    else:
        vs.append(x)
print("vs: " + str(vs))
""".format(
    device=device,
    nargs=nargs,
    wrap_variable=wrap_variable,
    batch_size=batch_size,
)
# Statement under measurement; `vs` is created by the setup code.
stmt = "_extract_apply_in_data(vs)"

# Run the setup once, execute `stmt` `number` times, and print the total
# elapsed time reported by timeit.
timer = timeit.Timer(stmt, setup=setup)
print(timer.timeit(number=number))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment