Created January 12, 2014 00:36
-
-
Save jsalvatier/8378901 to your computer and use it in GitHub Desktop.
Theano cumsum Op (I think the gradient only works for 1-D arrays)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CumSum(theano.Op):
    """Theano Op wrapping ``numpy.cumsum``.

    Each node has two inputs, ``(x, axis)``:

    * ``axis`` is an integer tensor: cumulative sum along that axis; the
      output has the same type as ``x``.
    * ``axis`` is a generic ``theano.Constant`` holding ``None``: ``x`` is
      flattened first; the output is a 1-d tensor with ``x``'s dtype.
    """

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        # Defining __eq__ without __hash__ makes instances unhashable on
        # Python 3; all CumSum instances compare equal, so hash the type.
        return hash(type(self))

    def __str__(self):
        return self.__class__.__name__

    @staticmethod
    def _is_none_axis(axis):
        """Return True if ``axis`` is the Constant(None) 'flatten' placeholder."""
        return isinstance(axis, theano.Constant) and axis.data is None

    def make_node(self, input, axis=-1):
        """Build the Apply node for ``cumsum(input, axis)``.

        ``axis`` may be an int / integer tensor, ``None``, or the
        ``Constant(None)`` placeholder created by an earlier call (``grad``
        passes that placeholder back in).
        """
        input = theano.tensor.as_tensor_variable(input)
        if axis is None or self._is_none_axis(axis):
            axis = theano.Constant(theano.gof.generic, None)
            # axis=None flattens the array before summing, so the output is 1-d.
            out_type = theano.tensor.TensorType(
                dtype=input.dtype, broadcastable=(False,))()
        else:
            axis = theano.tensor.as_tensor_variable(axis)
            out_type = input.type()
        return theano.Apply(self, [input, axis], [out_type])

    def perform(self, node, inputs, output_storage):
        """Compute the cumulative sum with numpy."""
        x, axis = inputs
        z = output_storage[0]
        # numpy promotes small integer dtypes (e.g. int8) to the platform
        # int; force the dtype declared on the output variable instead so
        # the result matches the graph's type.
        z[0] = np.cumsum(x, axis, dtype=node.outputs[0].dtype)

    def grad(self, inputs, grads):
        """Gradient of cumsum: a reversed cumsum of the output gradient.

        Since out[i] = sum_{k<=i} x[k], we have
        grad_x[k] = sum_{i>=k} gz[i]: reverse ``gz`` along the summed axis,
        cumsum it, then reverse back.

        Raises:
            NotImplementedError: if ``axis`` is symbolic (not a constant),
                because the axis to reverse must be known when building
                the graph.
        """
        x, axis = inputs
        [gz] = grads
        axis_grad = theano.gradient.grad_undefined(
            self, 1, axis,
            "cumsum is not defined for non-integer axes so"
            " cumsum(x, axis+eps) is undefined")
        if self._is_none_axis(axis):
            # The forward pass flattened x, so gz is 1-d; restore x's shape.
            gx = self(gz[::-1], axis)[::-1].reshape(x.shape, ndim=x.ndim)
            return [gx, axis_grad]
        if not isinstance(axis, theano.Constant):
            raise NotImplementedError(
                "CumSum.grad requires a constant axis to know which "
                "dimension to reverse")
        # Reverse only along the summed axis (not unconditionally axis 0);
        # negative axes index the slice list from the end, as in numpy.
        reverse = [slice(None)] * gz.ndim
        reverse[int(axis.data)] = slice(None, None, -1)
        reverse = tuple(reverse)
        gx = self(gz[reverse], axis)[reverse]
        return [gx, axis_grad]

    def infer_shape(self, node, inputs_shapes):
        """Return the output shape without running the Op."""
        x_shape = inputs_shapes[0]
        if self._is_none_axis(node.inputs[1]):
            # axis=None: the array is flattened before summing, so the output
            # length is the product of the input dims (1 for a 0-d input).
            size = theano.tensor.mul(*x_shape) if x_shape else 1
            return [(size,)]
        # Integer axis: input and output have the same rank and shape.
        assert node.inputs[0].ndim == node.outputs[0].ndim
        assert inputs_shapes[1] == ()
        return [x_shape]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
This is now merged into the Theano development version. It will be included in Theano 0.6.1.