Skip to content

Instantly share code, notes, and snippets.

@sisp
Created March 8, 2015 16:51
Show Gist options
  • Save sisp/52b9a757e76080771a5f to your computer and use it in GitHub Desktop.
diff --git a/theano/tensor/basic.py b/theano/tensor/basic.py
index 8f734e2..52327a6 100644
--- a/theano/tensor/basic.py
+++ b/theano/tensor/basic.py
@@ -3804,6 +3804,202 @@ def vertical_stack(*args):
return concatenate(args, axis=0)
+class FixUnknownDimension(Op):
+ """Infer an unknown dimension indicated by `-1`.
+
+ In `Reshape` one dimension can be provided as `-1` which means the size of
+ this dimension is inferred. This op computes the missing dimension.
+ """
+
+ def __init__(self, ndim):
+ self.ndim = ndim
+
+ def __eq__(self, other):
+ return (type(other) is type(self)) and (other.ndim == self.ndim)
+
+ def __hash__(self):
+ return hash(type(self)) ^ hash(self.ndim)
+
+ def __str__(self):
+ return '%s{%s}' % (self.__class__.__name__, self.ndim)
+
+ def make_node(self, newshape, size):
+ newshape = as_tensor_variable(newshape, ndim=1)
+ if not newshape.dtype.startswith('int'):
+ raise TypeError('`newshape` must be integers',
+ newshape, newshape.dtype)
+ assert newshape.ndim == 1
+
+ size = as_tensor_variable(size, ndim=0)
+ if not size.dtype.startswith('int'):
+ raise TypeError('`size` must be an integer', size, size.dtype)
+ assert size.ndim == 0
+
+ return gof.Apply(self, [newshape, size], [newshape.type()])
+
+ def perform(self, node, inp, out_):
+ newshape, size = inp
+ out, = out_
+
+ if newshape.shape[0] != self.ndim:
+ raise ValueError('Argument `newshape` to '
+ 'FixUnknownDimension.perform has incorrect '
+ 'length %d, should be %d.'
+ % (newshape.shape[0], self.ndim), newshape)
+
+ if size.ndim != 0:
+ raise ValueError('Argument `size` to FixUnknownDimension.perform '
+ 'must be a scalar (0 dimensions). (%d dimensions)'
+ % size.ndim)
+
+ i_unknown = newshape < 0
+ n_unknown = i_unknown.sum()
+
+ if (out[0] is None) or (out[0].shape != newshape.shape):
+ out[0] = numpy.empty_like(newshape)
+
+ out[0][:] = newshape
+
+ if n_unknown == 0:
+ if newshape.prod() != size:
+ raise ValueError('Total size must not change.')
+ elif n_unknown == 1:
+ known = newshape[~i_unknown].prod()
+ if (known > 0) and (size % known == 0):
+ out[0][i_unknown] = size // known
+ else:
+ raise ValueError('Total size must not change.')
+ else:
+ raise ValueError('Can only specify one unknown dimension.')
+
+ def infer_shape(self, node, ishapes):
+ return [ishapes[0]]
+
+ def c_code_cache_version(self):
+ return (1,)
+
+ def c_support_code(self):
+ """
+ This code is borrowed from <numpy/core/src/multiarray/shape.c>.
+ """
+ return """
+ static int
+ _fix_unknown_dimension(PyArray_Dims *newshape, npy_intp s_original)
+ {
+ npy_intp *dimensions;
+ npy_intp i_unknown, s_known;
+ int i, n;
+ static char msg[] = "total size of new array must be unchanged";
+
+ dimensions = newshape->ptr;
+ n = newshape->len;
+ s_known = 1;
+ i_unknown = -1;
+
+ for (i = 0; i < n; i++) {
+ if (dimensions[i] < 0) {
+ if (i_unknown == -1) {
+ i_unknown = i;
+ }
+ else {
+ PyErr_SetString(PyExc_ValueError,
+ "can only specify one" \
+ " unknown dimension");
+ return -1;
+ }
+ }
+ else {
+ s_known *= dimensions[i];
+ }
+ }
+
+ if (i_unknown >= 0) {
+ if ((s_known == 0) || (s_original % s_known != 0)) {
+ PyErr_SetString(PyExc_ValueError, msg);
+ return -1;
+ }
+ dimensions[i_unknown] = s_original/s_known;
+ }
+ else {
+ if (s_original != s_known) {
+ PyErr_SetString(PyExc_ValueError, msg);
+ return -1;
+ }
+ }
+ return 0;
+ }
+ """
+
+ def c_code(self, node, name, inputs, outputs, sub):
+ newshape, size = inputs
+ out, = outputs
+ ndim = self.ndim
+ dtype_newshape = node.inputs[0].type.dtype_specs()[1]
+ dtype_size = node.inputs[1].type.dtype_specs()[1]
+ fail = sub['fail']
+ return """
+ if (PyArray_DIMS(%(newshape)s)[0] != %(ndim)s)
+ {
+ PyErr_Format(PyExc_ValueError,
+ "Argument `newshape` to FixUnknownDimension.c_code "
+ "has incorrect length %%ld, should be %%ld.",
+ (long int)PyArray_DIMS(%(newshape)s)[0],
+ (long int)%(ndim)s);
+ %(fail)s;
+ }
+
+ if (PyArray_NDIM(%(size)s) != 0)
+ {
+ PyErr_Format(PyExc_ValueError,
+ "Argument `size` to FixUnknownDimension.c_code must "
+ "be a scalar (0 dimensions). (%%ld dimensions)",
+ (long int)PyArray_NDIM(%(size)s));
+ %(fail)s;
+ }
+
+ // Check if output memory can be reused. If not, allocate new memory.
+ if ((NULL == %(out)s) || (PyArray_DIMS(%(out)s)[0] != %(ndim)s))
+ {
+ if (NULL != %(out)s)
+ Py_XDECREF(%(out)s);
+
+ %(out)s = (PyArrayObject*) PyArray_SimpleNew(
+ PyArray_NDIM(%(newshape)s),
+ PyArray_DIMS(%(newshape)s),
+ NPY_INTP);
+
+ if (!%(out)s)
+ {
+ PyErr_SetString(PyExc_MemoryError, "Failed to alloc output.");
+ %(fail)s
+ }
+ }
+
+ PyArray_Dims newshape;
+ newshape.ptr = (npy_intp*)PyArray_DATA(%(out)s);
+ newshape.len = %(ndim)s;
+ for (int i = 0; i < %(ndim)s; ++i)
+ {
+ // -- We do not want an explicit cast here. `newshape` can be any
+ // -- int* dtype. The compiler will implicitly upcast it, but
+ // -- will err if this would downcast. This could happen if the
+ // -- user passes an int64 dtype, but npy_intp ends up being int32.
+ newshape.ptr[i] = ((%(dtype_newshape)s*)(
+ PyArray_BYTES(%(newshape)s) +
+ PyArray_STRIDES(%(newshape)s)[0] * i))[0];
+ }
+
+ {
+ const npy_intp size = *((%(dtype_size)s*)PyArray_BYTES(%(size)s));
+ if (_fix_unknown_dimension(&newshape, size) < 0) {
+ // The error message should have been set by
+ // `_fix_unknown_dimension`.
+ %(fail)s;
+ }
+ }
+ """ % locals()
+
+
class Reshape(Op):
"""Perform a reshape operation of the input x to the new shape shp.
@@ -3892,54 +4088,9 @@ class Reshape(Op):
return self(eval_points[0], *inputs[1:], **dict(return_list=True))
def infer_shape(self, node, ishapes):
- # inputs[1] can contain at most one value of '-1', meaning the actual
- # shape of the output will be automatically computed by reshape, so
- # that the total number of elements stays the same.
- # TODO: Maybe put that formula here?
- # It's not trivial, because we would have to check if the product of
- # all the non-minus-one shapes is a divisor of the product of the
- # original shapes.
-
- # The following expression leads to cycles in feature_shape,
- # because it tries to replace the Shape_i node by the switch
- # statement, which depends on Shape_i.
- # return [tuple([switch(eq(node.inputs[1][i], -1),
- # theano.tensor.opt.Shape_i(i)(node.outputs[0]),
- # node.inputs[1][i])
- # for i in xrange(self.ndim)]
- # )]
-
- # Here, we only simplify if the shape (node.inputs[1]) is a constant,
- # ideally it would suffice to check that it is always non-negative.
-
- requ = node.inputs[1]
- if isinstance(requ, theano.tensor.TensorConstant):
- requ = list(requ.data)
- requ_part = [ele for ele in requ if ele != -1]
- crit = len(requ) - len(requ_part)
- if crit == 1 and len(requ_part) > 0:
- missing = mul(*ishapes[0]) // mul(*requ_part)
- for i, ele in enumerate(requ):
- if ele == -1:
- requ[i] = missing
- elif crit == 1: # we reshape to -1
- requ = [mul(*ishapes[0])]
- elif crit > 1:
- raise ValueError('shape argument to Reshape.perform'
- ' must have at most one entry equal to -1')
- return [requ]
- else:
- oshape = []
- for i in xrange(self.ndim):
- default_os_i = theano.tensor.opt.Shape_i(i)(node.outputs[0])
- try:
- os_i = get_scalar_constant_value(node.inputs[1][i]).item()
- if os_i == -1:
- os_i = default_os_i
- except NotScalarConstantError:
- os_i = default_os_i
- oshape.append(os_i)
- return [tuple(oshape)]
+ newshape = node.inputs[1]
+ outshape = FixUnknownDimension(self.ndim)(newshape, mul(*ishapes[0]))
+ return [tuple(outshape[i] for i in xrange(self.ndim))]
def c_code_cache_version(self):
return (6,)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment