Skip to content

Instantly share code, notes, and snippets.

@noskill
Last active August 29, 2015 14:20
Show Gist options
  • Save noskill/2b9ba17cccbb402959e8 to your computer and use it in GitHub Desktop.
Save noskill/2b9ba17cccbb402959e8 to your computer and use it in GitHub Desktop.
Test the gradient of Theano's sort op (T.sort).
import numpy
import theano
import theano.tensor as T
# Checks the gradient of Theano's sort op (T.sort) on a small 3-D input:
#  * d/dx sum(sort(x**2)) must equal d/dx sum(x**2) == 2*x, because sorting
#    only permutes elements and a sum is permutation-invariant;
#  * d/dx sum(last entry of sort(x**2) along `axis`) must be 2*x at the
#    position holding the largest x**2 along `axis`, and 0 everywhere else.

# 'raise' makes Theano evaluate every expression on the attached test values
# while the graph is built, so shape/type errors surface immediately.
theano.config.compute_test_value = 'raise'
# 'None' disables graph optimizations (presumably so the sort op's own grad
# is what gets exercised, not an optimized rewrite of it).
theano.config.optimizer = 'None'

# Fixed seed so a failing run can be reproduced. Values lie in [1, 2), so the
# squares are positive and (almost surely) distinct: no ties in the sort.
rng = numpy.random.RandomState(1234)
ar = rng.random_sample((2, 3, 2))
ar += 1

x = T.dtensor3('x')
x.tag.test_value = ar
axis = -2
y = x**2
y_sort = T.sort(y, axis=axis)

# Keep only the last (largest) entry along `axis`. The index is built from
# `axis` rather than hard-coded, so it stays in sync if `axis` is changed.
last_along_axis = [slice(None)] * ar.ndim
last_along_axis[axis] = slice(-1, None)
y_slice = T.sum(y_sort[tuple(last_along_axis)])

f = theano.function([x], T.grad(T.sum(y), x))
f_slice = theano.function([x], T.grad(y_slice, x))
f_sort = theano.function([x], T.grad(T.sum(y_sort), x))

print('source array:')
print(ar)
print('argsort:')
print(numpy.argsort(ar ** 2, axis=axis))
print("grad with sort, slice last elements:")
grad_slice = f_slice(ar)  # already has ar's shape; no reshape needed
print(grad_slice)
print('grad without sort:')
grad_plain = f(ar)
print(grad_plain)

# Sorting only permutes elements, so it must not change the gradient of a sum.
# allclose instead of == because these are floating-point results.
assert numpy.allclose(f_sort(ar), grad_plain), \
    'gradient through sort differs from gradient without sort'

# Only the element that lands last after sorting (the max of x**2 along
# `axis`) receives gradient d(x**2)/dx = 2*x; all other entries get 0.
squares = ar ** 2
is_max = squares == squares.max(axis=axis, keepdims=True)
expected_slice = numpy.where(is_max, 2 * ar, 0.0)
assert numpy.allclose(grad_slice, expected_slice), \
    'gradient of the sliced sort does not select the maximum along axis'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment