@sisp · Created February 24, 2015 13:59
CrossentropySoftmax1HotWithBiasDx + local_useless_incsubtensor_alloc
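The diff below touches two places in theano/tensor/nnet/nnet.py. In the C code of CrossentropySoftmax1HotWithBiasDx it relaxes the shape checks so that dnll may be a broadcastable (length-1) vector: dnll.shape[0] only has to match sm.shape[0] and y_idx.shape[0] when it is larger than 1, the read index into dnll falls back to 0 in the broadcastable case, and c_code_cache_version is bumped from 3 to 4. In local_advanced_indexing_crossentropy_onehot_grad it additionally accepts an incr that is broadcastable in every dimension, which (as the gist title suggests) can arise once local_useless_incsubtensor_alloc strips the Alloc around the increment value. A usage sketch of the kind of graph this targets follows the diff.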
diff --git a/theano/tensor/nnet/nnet.py b/theano/tensor/nnet/nnet.py
index 788213d..e468385 100644
--- a/theano/tensor/nnet/nnet.py
+++ b/theano/tensor/nnet/nnet.py
@@ -1099,7 +1099,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
return [g_dy, g_sm, g_y_idx]
def c_code_cache_version(self):
- return (3,)
+ return (4,)
def c_code(self, node, name, inp, out, sub):
dnll, sm, y_idx = inp
@@ -1128,7 +1128,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
PyErr_SetString(PyExc_ValueError, "rank error");
%(fail)s;
}
- if (PyArray_DIMS(%(dnll)s)[0] != PyArray_DIMS(%(sm)s)[0])
+ if (PyArray_DIMS(%(dnll)s)[0] != PyArray_DIMS(%(sm)s)[0] && PyArray_DIMS(%(dnll)s)[0] != 1)
{
PyErr_Format(PyExc_ValueError,
"dnll.shape[0] (%%ld) != sm.shape[0] (%%ld)",
@@ -1136,7 +1136,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
(long int)PyArray_DIMS(%(sm)s)[0]);
%(fail)s;
}
- if (PyArray_DIMS(%(dnll)s)[0] != PyArray_DIMS(%(y_idx)s)[0])
+ if (PyArray_DIMS(%(dnll)s)[0] != PyArray_DIMS(%(y_idx)s)[0] && PyArray_DIMS(%(dnll)s)[0] != 1)
{
PyErr_Format(PyExc_ValueError,
"dnll.shape[0] (%%ld) != y_idx.shape[0] (%%ld)",
@@ -1161,7 +1161,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
for (size_t i = 0; i < PyArray_DIMS(%(dx)s)[0]; ++i)
{
- const dtype_%(dnll)s dnll_i = ((dtype_%(dnll)s*)(PyArray_BYTES(%(dnll)s) + PyArray_STRIDES(%(dnll)s)[0] * i))[0];
+ const dtype_%(dnll)s dnll_i = ((dtype_%(dnll)s*)(PyArray_BYTES(%(dnll)s) + PyArray_STRIDES(%(dnll)s)[0] * (PyArray_DIMS(%(dnll)s)[0] > 1 ? i : 0)))[0];
const %(y_idx_type) s y_i = ((%(y_idx_type)s*)(PyArray_BYTES(%(y_idx)s) + PyArray_STRIDES(%(y_idx)s)[0] * i))[0];
@@ -1736,7 +1736,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
# if the graph is valid, they have the same shape, so we
# also know that z has the right shape.
- if incr.type not in (dvector, fvector):
+ if incr.type not in (dvector, fvector) and not all(incr.broadcastable):
return
# here we know that we are incrementing some part of matrix z by a vector
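For illustration only, here is a minimal sketch, assuming a standard Theano installation, of the kind of graph this patch appears to target; the variable names, shapes, and learning rate are made up for the example and are not taken from the patch. Taking the mean of the per-example negative log-likelihood gives a constant incoming gradient on nll, so after graph optimization the dnll fed to CrossentropySoftmax1HotWithBiasDx can plausibly end up broadcastable (shape[0] == 1), which the unpatched shape checks would reject.

# Illustrative sketch only; names, shapes, and initial values are assumptions,
# not part of the patch.
import numpy as np
import theano
import theano.tensor as T

floatX = theano.config.floatX

x = T.matrix('x')                      # (batch, n_in)
y = T.ivector('y')                     # integer class labels, shape (batch,)
W = theano.shared(np.zeros((4, 3), dtype=floatX), name='W')
b = theano.shared(np.zeros(3, dtype=floatX), name='b')

# Softmax output and the advanced-indexing form of the negative log-likelihood
# that local_advanced_indexing_crossentropy_onehot_grad looks for.
p_y = T.nnet.softmax(T.dot(x, W) + b)
nll = -T.log(p_y)[T.arange(y.shape[0]), y]

# Taking the mean makes the gradient flowing into nll a broadcasted constant
# (1/batch); after optimization, the dnll input of
# CrossentropySoftmax1HotWithBiasDx may then have shape[0] == 1.
cost = nll.mean()
g_W, g_b = T.grad(cost, [W, b])

train = theano.function([x, y], cost,
                        updates=[(W, W - 0.1 * g_W),
                                 (b, b - 0.1 * g_b)])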