Skip to content

Instantly share code, notes, and snippets.

@dwf
Created January 18, 2014 00:16
Show Gist options
  • Save dwf/8484148 to your computer and use it in GitHub Desktop.
Save dwf/8484148 to your computer and use it in GitHub Desktop.
A skeleton cumulative sum op for Theano.
import theano
import numpy
class CumSumOp(theano.Op):
    """Skeleton Theano op computing a cumulative sum via ``numpy.cumsum``.

    Use as ``CumSumOp(axis)(input_arg)``.

    Parameters
    ----------
    axis : int or None, optional
        Axis along which the cumulative sum is taken. ``None`` (the
        default) follows ``numpy.cumsum``: the input is flattened and the
        result is a 1-D vector.
    """

    def __init__(self, axis=None):
        self.axis = axis

    def __eq__(self, other):
        # Two ops are interchangeable when they are the same op class
        # configured with the same axis.
        return type(self) == type(other) and self.axis == other.axis

    def __hash__(self):
        # Kept consistent with __eq__: built from the same two components.
        return hash(type(self)) ^ hash(self.axis)

    def make_node(self, x):
        """Build the symbolic Apply node for input ``x``.

        Parameters
        ----------
        x : anything accepted by ``theano.tensor.as_tensor_variable``

        Returns
        -------
        theano.Apply
            Node with one input (the converted tensor) and one output.

        Raises
        ------
        ValueError
            If ``axis`` is not ``None`` and lies outside
            ``[-ndim, ndim)`` for the converted input.
        """
        x_ = theano.tensor.as_tensor_variable(x)
        axis = self.axis
        if axis is None:
            # numpy.cumsum flattens its input when axis is None, so the
            # output is always a vector regardless of the input's rank.
            # It is declared non-broadcastable because its length is not
            # known to be 1.
            out = theano.tensor.TensorType(
                dtype=x_.dtype, broadcastable=(False,))()
        else:
            # Validate against the converted variable: the raw ``x`` may be
            # a list or scalar with no ``ndim`` attribute.
            if not (-x_.ndim <= axis < x_.ndim):
                raise ValueError(
                    'axis(={0}) out of bounds for input with {1} '
                    'dimension(s)'.format(axis, x_.ndim))
            # A cumulative sum along one axis leaves the shape unchanged,
            # so reusing the input's type (including its broadcastable
            # pattern) is valid in this branch.
            out = x_.type()
        return theano.Apply(self, inputs=[x_], outputs=[out])

    def perform(self, node, inputs, output_storage):
        """Compute the cumulative sum numerically.

        Stores ``numpy.cumsum(inputs[0], axis=self.axis)`` into
        ``output_storage[0][0]``.
        """
        (x,) = inputs
        # Pin the dtype to the declared output dtype: without it,
        # numpy.cumsum promotes small integer inputs to the platform
        # integer, which would not match the node's output type.
        output_storage[0][0] = numpy.cumsum(
            x, axis=self.axis, dtype=node.outputs[0].type.dtype)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment