Skip to content

Instantly share code, notes, and snippets.

@wkcn
Created May 13, 2018 08:04
Show Gist options
  • Save wkcn/715f30defd736ce663a2f7399ffe960e to your computer and use it in GitHub Desktop.
Save wkcn/715f30defd736ce663a2f7399ffe960e to your computer and use it in GitHub Desktop.
CountOPTime-mx
import mxnet as mx
import time
from distutils.util import strtobool
# Wall-clock reference taken when the module is imported. Timestamps passed
# between CountTimeOP instances are stored relative to this value, presumably
# to keep them small so a low-precision NDArray can hold them accurately.
BT = time.time()


class CountTimeOP(mx.operator.CustomOp):
    """Identity operator that prints the wall-clock time between two probes.

    Every instance copies its ``data`` input to its first output and writes
    the current time (seconds since ``BT``) to its second output. A non-first
    instance also receives the previous probe's timestamp as its second input
    and, on each forward pass, prints ``"<cname>:<elapsed seconds>"``.

    Parameters
    ----------
    first : bool
        True for the first probe of a chain: it has no timestamp input and
        prints nothing.
    cname : str
        Label printed in front of the elapsed time.
    """

    def __init__(self, first, cname):
        super(CountTimeOP, self).__init__()
        self.first = first
        self.cname = cname

    def forward(self, is_train, req, in_data, out_data, aux):
        if not self.first:
            # Wait until the data input is ready so the interval covers the
            # work that produced it. wait_to_read() synchronises without the
            # full device-to-host copy asnumpy() did; that copy ran before the
            # clock was read and was therefore counted in dt.
            in_data[0].wait_to_read()
            # The clock is read before the previous timestamp is fetched, so
            # fetching it is not counted in dt.
            dt = (time.time() - BT) - in_data[1].asscalar()
            print("%s:%f" % (self.cname, dt))
        # Honour the requested write mode (write/inplace/add/null) instead of
        # always overwriting the outputs.
        self.assign(out_data[0], req[0], in_data[0])
        self.assign(out_data[1], req[1], time.time() - BT)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        # Identity op: the data gradient passes straight through.
        self.assign(in_grad[0], req[0], out_grad[0])
        # Only non-first probes have a timestamp argument ('t'); the first
        # probe's in_grad has a single entry, so in_grad[1] must not be
        # touched there. The timestamp carries no gradient.
        if not self.first:
            self.assign(in_grad[1], req[1], 0)
@mx.operator.register("CountTimeOP")
class CountTimeProp(mx.operator.CustomOpProp):
    """Property class registering the ``CountTimeOP`` custom operator.

    Arguments are ``data`` for the first probe and ``data, t`` otherwise.
    Outputs are always ``out_data`` (a copy of ``data``) and ``out_t`` (a
    one-element timestamp).

    Parameters
    ----------
    first : str
        Boolean flag given as a string (e.g. ``"True"``, ``"0"``), parsed with
        ``strtobool``, which requires a string.
    cname : str
        Label forwarded to the operator and printed with each measurement.
    """

    def __init__(self, first, cname):
        super(CountTimeProp, self).__init__(need_top_grad=True)
        self.first = strtobool(first)
        self.cname = cname

    def list_arguments(self):
        if self.first:
            return ['data']
        return ['data', 't']

    def list_outputs(self):
        return ['out_data', 'out_t']

    def infer_shape(self, in_shape):
        dshape = in_shape[0]
        # Pin the timestamp argument to (1,) so shape inference succeeds even
        # when the caller leaves 't' unspecified.
        arg_shapes = [dshape] if self.first else [dshape, (1,)]
        # (argument shapes, output shapes, auxiliary-state shapes)
        return arg_shapes, [dshape, (1,)], []

    def infer_type(self, in_type):
        # MXNet's infer_type works with numpy dtype types such as np.float32.
        import numpy as np
        dtype = in_type[0]
        # The default inference would give the timestamps the data's dtype,
        # which truncates (integer data) or coarsely rounds (float16 data) the
        # time; always carry them as float32.
        arg_types = [dtype] if self.first else [dtype, np.float32]
        # (argument types, output types, auxiliary-state types)
        return arg_types, [dtype, np.float32], []

    def create_operator(self, ctx, shapes, dtypes):
        return CountTimeOP(first=self.first, cname=self.cname)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment