Skip to content

Instantly share code, notes, and snippets.

@peterroelants
Created November 20, 2014 19:23
Show Gist options
  • Save peterroelants/4946cdbf189c5e75f2b7 to your computer and use it in GitHub Desktop.
Save peterroelants/4946cdbf189c5e75f2b7 to your computer and use it in GitHub Desktop.
Theano gradient of sparse matrix mulitplication
{
"metadata": {
"name": "",
"signature": "sha256:14e43d708aea630a4c8f33c6cabbbe3227646db0888b99449a3db1735adc1a9d"
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "code",
"collapsed": false,
"input": [
"import numpy as np\n",
"import scipy.sparse as sp\n",
"import theano\n",
"import theano.tensor as T\n",
"import theano.sparse"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 1
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Test out regular theano computations\n",
"W = T.matrix(name='W', dtype=theano.config.floatX)\n",
"b = T.matrix(name='b', dtype=theano.config.floatX)\n",
"x = T.matrix(name='x', dtype=theano.config.floatX)\n",
"b_prime = T.matrix(name='b_prime', dtype=theano.config.floatX)\n",
"\n",
"# Define some input variables\n",
"x_vals = np.matrix([[0,0,1.0]])\n",
"W_vals = np.matrix([[1,0],[0,2],[0,3.0]])\n",
"b_vals = np.matrix([[0.5,0.2]])\n",
"b_prime_vals = np.matrix([[0.1,0.4, 0.7]])\n",
"\n",
"# Encode the input\n",
"# Matrix sizes: (1x3 . 3x2) + 1x2 = 1x2\n",
"h = T.nnet.sigmoid(x.dot(W) + b)\n",
"# Decode the output\n",
"# Matrix sizes: (1x2 . 2x3) + 1x3 = 1x3\n",
"z = T.nnet.sigmoid(h.dot(W.T) + b_prime)\n",
"\n",
"# Try do do a gradient\n",
"cost = T.sum(x * T.log(z))\n",
"gc = T.grad(cost, W)\n",
"f_gc = theano.function([x, W, b, b_prime], gc, mode='DebugMode')\n",
"gc_vals = f_gc(x_vals, W_vals, b_vals, b_prime_vals)\n",
"print type(gc_vals)\n",
"print gc_vals"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"<type 'numpy.ndarray'>\n",
"[[ 0. 0. ]\n",
" [ 0. 0. ]\n",
" [ 0.01683987 0.02904842]]\n"
]
},
{
"output_type": "stream",
"stream": "stderr",
"text": [
"WARNING: ('Stride mismatch', ((3, 1), (3, 1), (8, 24), (8, 8), 'DimShuffle{1,0}'))\n",
"WARNING: ('Stride mismatch', ((3, 1), (3, 1), (8, 24), (8, 8), 'InplaceDimShuffle{1,0}'))\n",
"WARNING: ('Stride mismatch', ((2, 1), (2, 1), (8, 16), (8, 8), 'DimShuffle{1,0}'))\n",
"WARNING: ('Stride mismatch', ((1, 2), (1, 2), (8, 8), (16, 8), 'DimShuffle{1,0}'))\n",
"WARNING: ('Stride mismatch', ((3, 1), (3, 1), (8, 24), (8, 8), 'InplaceDimShuffle{1,0}'))\n",
"WARNING: ('Stride mismatch', ((3, 1), (3, 1), (8, 24), (8, 8), 'DimShuffle{1,0}'))\n"
]
}
],
"prompt_number": 2
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Test out sparse theano computations\n",
"W = T.matrix(name='W', dtype=theano.config.floatX)\n",
"b = T.matrix(name='b', dtype=theano.config.floatX)\n",
"x = theano.sparse.csr_matrix(name='x', dtype=theano.config.floatX)\n",
"b_prime = T.matrix(name='b_prime', dtype=theano.config.floatX)\n",
"\n",
"# Define some input variables\n",
"x_vals = sp.csr_matrix(np.matrix([[0,0,1.0]]))\n",
"W_vals = np.matrix([[1,0],[0,2],[0,3.0]])\n",
"b_vals = np.matrix([[0.5,0.2]])\n",
"b_prime_vals = np.matrix([[0.1,0.4, 0.7]])\n",
"\n",
"# Encode the input\n",
"# (1x3 . 3x2) + 1x2 = 1x2\n",
"h = T.nnet.sigmoid(theano.sparse.basic.dot(x, W) + b)\n",
"# Decode the output\n",
"# (1x2 . 2x3) + 1x3 = 1x3\n",
"z = T.nnet.sigmoid(h.dot(W.T) + b_prime)\n",
"\n",
"# Try to do a gradient\n",
"# cost = T.sum((x-z)**2) # This works!\n",
"cost = T.sum(x * T.log(z))\n",
"gc = T.grad(cost, W)\n",
"f_gc = theano.function([x, W, b, b_prime], gc, mode='DebugMode')\n",
"gc_vals = f_gc(x_vals, W_vals, b_vals, b_prime_vals)\n",
"print type(gc_vals)\n",
"print gc_vals"
],
"language": "python",
"metadata": {},
"outputs": [
{
"ename": "AsTensorError",
"evalue": "('Variable type field must be a TensorType.', SparseVariable{csr,float64}, Sparse[float64, csr])",
"output_type": "pyerr",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n\u001b[0;31mAsTensorError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-3-dcab653a26cb>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;31m# Try to do a gradient\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;31m# cost = T.sum((x-z)**2) # This works!\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0mcost\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mT\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mT\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0mgc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mT\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcost\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mW\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0mf_gc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtheano\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mW\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_prime\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'DebugMode'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/Users/peterroelants/anaconda/lib/python2.7/site-packages/theano/tensor/basic.pyc\u001b[0m in \u001b[0;36msum\u001b[0;34m(input, axis, dtype, keepdims, acc_dtype)\u001b[0m\n\u001b[1;32m 2630\u001b[0m \"\"\"\n\u001b[1;32m 2631\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2632\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0melemwise\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0macc_dtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0macc_dtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2633\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2634\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mkeepdims\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/Users/peterroelants/anaconda/lib/python2.7/site-packages/theano/gof/op.pyc\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[1;32m 491\u001b[0m \"\"\"\n\u001b[1;32m 492\u001b[0m \u001b[0mreturn_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'return_list'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 493\u001b[0;31m \u001b[0mnode\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake_node\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 494\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_stack_trace_on_call\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 495\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_tag_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/Users/peterroelants/anaconda/lib/python2.7/site-packages/theano/tensor/elemwise.pyc\u001b[0m in \u001b[0;36mmake_node\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 1809\u001b[0m \u001b[0;31m# we can infer what dtype should be, and create a node from an Op\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1810\u001b[0m \u001b[0;31m# of the appropriate dtype.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1811\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mas_tensor_variable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1812\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_output_dtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1813\u001b[0m \u001b[0macc_dtype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_acc_dtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/Users/peterroelants/anaconda/lib/python2.7/site-packages/theano/tensor/basic.pyc\u001b[0m in \u001b[0;36mas_tensor_variable\u001b[0;34m(x, name, ndim)\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTensorType\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 160\u001b[0m raise AsTensorError(\n\u001b[0;32m--> 161\u001b[0;31m \"Variable type field must be a TensorType.\", x, x.type)\n\u001b[0m\u001b[1;32m 162\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mndim\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAsTensorError\u001b[0m: ('Variable type field must be a TensorType.', SparseVariable{csr,float64}, Sparse[float64, csr])"
]
}
],
"prompt_number": 3
}
],
"metadata": {}
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment