Skip to content

Instantly share code, notes, and snippets.

@okapies
Last active January 21, 2019 10:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save okapies/207001ce67994fc7bd02736ef232aaab to your computer and use it in GitHub Desktop.
Save okapies/207001ce67994fc7bd02736ef232aaab to your computer and use it in GitHub Desktop.
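Benchmark script that times the construction of Chainer `Variable` objects (optionally via `_unsafe_variable`) around NumPy or ChainerX arrays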
import argparse
import timeit
parser = argparse.ArgumentParser()
parser.add_argument('--number', type=int, default=10000000)
parser.add_argument('--device', type=str, default="native:0")
parser.add_argument('--variables', type=int, default=1)
parser.add_argument('--require-grad', action="store_true", default=False)
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--unsafe', action="store_true", default=False)
parser.add_argument('--chainerx', action="store_true", default=False)
args = parser.parse_args()
setup = """
from chainer import Variable
try:
from chainer.variable import _unsafe_variable
except:
pass
import chainerx as chx
import numpy as np
device = chx.get_device('{}')
is_chainerx = {}
variables = {}
require_grad = {}
batch_size = {}
vs = []
print('Device: ' + str(device))
if is_chainerx:
chx.set_default_device(device)
a = chx.array(range(0, batch_size), dtype=chx.float32)
if require_grad:
a.require_grad()
else:
a = np.array(range(0, batch_size), dtype=np.float32)
print("a: " + str(a))
""".format(args.device, args.chainerx, args.variables, args.require_grad, args.batch_size)
if args.unsafe:
stmt = "_unsafe_variable(a, requires_grad={}, is_chainerx_array={})".format(args.require_grad, args.chainerx)
else:
stmt = "Variable(a, requires_grad={})".format(args.require_grad)
print(timeit.timeit(stmt, setup=setup, number=args.number))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment