Created
July 8, 2016 20:53
-
-
Save vene/8bfee84a1c7993f9cacdea6f30ad8896 to your computer and use it in GitHub Desktop.
sparsemax loss for Theano
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Author: Vlad Niculae <vlad@vene.ro>\n", | |
"# License: matching sparsemax_theano.py from https://github.com/Unbabel/sparsemax\n", | |
"\n", | |
"import numpy as np\n", | |
"import theano.tensor as T\n", | |
"import theano as tn\n", | |
"from theano import gof\n", | |
"\n", | |
"import sparsemax_theano" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"class SparsemaxLossGrad(gof.Op):\n", | |
"    \"\"\"Theano op for the gradient of the sparsemax loss w.r.t. the\n", | |
"    scores: sparsemax(x) - onehot(y).\n", | |
"\n", | |
"    Inputs: sm (double matrix, sparsemax of the scores), y (int vector\n", | |
"    of gold labels). Output: double matrix, same shape as sm.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    __props__ = ()\n", | |
"\n", | |
"    itypes = [T.dmatrix, T.ivector]\n", | |
"    otypes = [T.dmatrix]\n", | |
"\n", | |
"    def perform(self, node, input_storage, output_storage):\n", | |
"        sm, y = input_storage\n", | |
"        # Bug fix: an op must not modify its inputs unless it declares a\n", | |
"        # destroy_map; the original subtracted in place, corrupting the\n", | |
"        # sparsemax values for any other consumer of the same node.\n", | |
"        g = sm.copy()\n", | |
"        g[np.arange(g.shape[0]), y] -= 1\n", | |
"        output_storage[0][0] = g\n", | |
"\n", | |
"    def grad(self, *args):\n", | |
"        # Second-order gradients are not needed for this demo.\n", | |
"        raise NotImplementedError()\n", | |
"\n", | |
"    def infer_shape(self, node, shape):\n", | |
"        # Output shape equals the shape of the first input (sm).\n", | |
"        return [shape[0]]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Single shared op instance, used by SparsemaxLoss.grad below.\n", | |
"sparsemax_loss_grad = SparsemaxLossGrad()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"\n", | |
"class SparsemaxLoss(gof.Op):\n", | |
"    \"\"\"Theano op computing the sparsemax loss (Martins & Astudillo, 2016).\n", | |
"\n", | |
"    Inputs: x (double matrix of scores, one row per sample), y (int\n", | |
"    vector of gold labels). Output: double vector of per-sample losses.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    __props__ = ()\n", | |
"    itypes = [T.dmatrix, T.ivector]\n", | |
"    otypes = [T.dvector]\n", | |
"\n", | |
"    def perform(self, node, input_storage, output_storage):\n", | |
"        x, y = input_storage\n", | |
"\n", | |
"        # loss_i = -x[i, y_i] + 0.5 * sum_{j in supp}(x_ij^2 - tau_i^2) + 0.5\n", | |
"        loss = -x[np.arange(x.shape[0]), y]\n", | |
"\n", | |
"        for i in range(x.shape[0]):\n", | |
"            sm, tau, _ = sparsemax_theano.project_onto_simplex(x[i])\n", | |
"            # Only coordinates in the support (sm > 0) contribute.\n", | |
"            loss[i] += 0.5 * np.sum(x[i][np.abs(sm) > 1e-12] ** 2 - tau ** 2)\n", | |
"\n", | |
"        loss += 0.5\n", | |
"        output_storage[0][0] = loss\n", | |
"\n", | |
"    def grad(self, inp, grads):\n", | |
"        x, y = inp\n", | |
"        g_out = grads[0]\n", | |
"        sm = sparsemax_theano.sparsemax(x)\n", | |
"        g = sparsemax_loss_grad(sm, y)\n", | |
"        # Bug fix: the upstream gradient was assigned but ignored, so\n", | |
"        # e.g. T.grad(cost=loss.mean(), ...) missed the 1/n factor.\n", | |
"        # Broadcast it across the class dimension.\n", | |
"        return [g * g_out.dimshuffle(0, 'x'),\n", | |
"                T.zeros_like(y, dtype=tn.config.floatX)]\n", | |
"\n", | |
"    def R_op(self, inputs, eval_points):\n", | |
"        # NOTE(review): this forwards grad()'s two-element list, while an\n", | |
"        # R_op usually returns one entry per *output* -- confirm intent.\n", | |
"        if None in eval_points:\n", | |
"            return [None]\n", | |
"        return self.grad(inputs, eval_points)\n", | |
"\n", | |
"    def infer_shape(self, node, shapes):\n", | |
"        x_shape, y_shape = shapes\n", | |
"        # One scalar loss per row of x.\n", | |
"        return [(x_shape[0],)]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Callable op instance: sparsemax_loss(x, y) -> per-sample loss vector.\n", | |
"sparsemax_loss = SparsemaxLoss()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[ 1. 0. 0.]\n", | |
"[[ 1. 0. -1.]\n", | |
" [ 0. 0. 0.]\n", | |
" [ 0. 0. 0.]]\n" | |
] | |
} | |
], | |
"source": [ | |
"# Sanity-check the op on a tiny 3-sample, 3-class problem.\n", | |
"Z = np.array([[3, 1, 2], [1, 3, 2], [1, 3, 2]], dtype=np.double)\n", | |
"ytr = np.array([2, 1, 1], dtype=np.int32)\n", | |
"\n", | |
"Z_tn = T.dmatrix()\n", | |
"y_tn = T.ivector()\n", | |
"\n", | |
"# Forward pass: one sparsemax loss value per row of Z.\n", | |
"loss = sparsemax_loss(Z_tn, y_tn)\n", | |
"print(loss.eval({Z_tn: Z, y_tn: ytr}))\n", | |
"\n", | |
"\n", | |
"# Backward pass exercises the custom grad / SparsemaxLossGrad op.\n", | |
"h = T.grad(cost=loss.mean(), wrt=Z_tn)\n", | |
"\n", | |
"print(h.eval({Z_tn: Z, y_tn: ytr}))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.datasets import make_classification\n", | |
"from sklearn.cross_validation import train_test_split" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Synthetic 10-class problem: 5000 samples, 300 features (100 informative).\n", | |
"X, y = make_classification(n_samples=5000, n_features=300, n_informative=100,\n", | |
"                           n_classes=10, random_state=0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Hold out 10% of the data for evaluation (fixed seed for reproducibility).\n", | |
"X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.1, random_state=0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from lasagne.layers import InputLayer, DenseLayer, get_output, get_all_params\n", | |
"from lasagne.nonlinearities import softmax, rectify\n", | |
"from lasagne.objectives import categorical_crossentropy\n", | |
"from lasagne.updates import adam\n", | |
"\n", | |
"class OneLayerNN(object):\n", | |
"    \"\"\"One-hidden-layer MLP trained with either softmax cross-entropy\n", | |
"    or the sparsemax loss, selected by the `loss` argument.\"\"\"\n", | |
"    def __init__(self, n_features, n_hidden, n_classes, loss='softmax'):\n", | |
"        # loss: 'softmax' or 'sparsemax'. NOTE(review): any other value\n", | |
"        # leaves y_pred/loss undefined and raises NameError below.\n", | |
"        \n", | |
"        X = InputLayer(shape=(None, n_features))\n", | |
"        y = T.ivector()\n", | |
"\n", | |
"        # ReLU hidden layer; output layer is linear (raw scores).\n", | |
"        hidden = DenseLayer(X, n_hidden, nonlinearity=rectify)\n", | |
"        final = DenseLayer(hidden, n_classes, nonlinearity=None)\n", | |
"        \n", | |
"        out = get_output(final)\n", | |
"        # `loss` is rebound from the selector string to the loss tensor here.\n", | |
"        if loss == 'softmax':\n", | |
"            y_pred = softmax(out)\n", | |
"            loss = categorical_crossentropy(y_pred, y).mean()\n", | |
"        \n", | |
"        elif loss == 'sparsemax':\n", | |
"            y_pred = sparsemax_theano.sparsemax(out)\n", | |
"            loss = sparsemax_loss(out, y).mean()\n", | |
"        # train(X, y) -> scalar loss, updating all params with Adam;\n", | |
"        # predict(X) -> matrix of class scores/probabilities.\n", | |
"        self.train = tn.function([X.input_var, y],\n", | |
"                                 loss,\n", | |
"                                 updates=adam(loss,\n", | |
"                                              get_all_params(final)))\n", | |
"        \n", | |
"        self.predict = tn.function([X.input_var], y_pred) " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"def score(nn, X, y):\n", | |
"    \"\"\"Accuracy of the model's argmax predictions on (X, y).\"\"\"\n", | |
"    y_hat = nn.predict(X).argmax(axis=1)\n", | |
"    return (y_hat == y).mean()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# MLP with 50 hidden units trained with the sparsemax loss.\n", | |
"nn = OneLayerNN(n_features=X.shape[1], n_hidden=50, n_classes=10, loss=\"sparsemax\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Acc: 0.11\n", | |
". loss=7.123902198078001 train acc=0.12 test_acc=0.11\n", | |
". loss=4.507310067407898 train acc=0.17 test_acc=0.14\n", | |
". loss=2.9644636319351263 train acc=0.24 test_acc=0.17\n", | |
". loss=2.019472765280979 train acc=0.31 test_acc=0.23\n", | |
". loss=1.4372722258354609 train acc=0.38 test_acc=0.24\n", | |
". loss=1.0558284703777132 train acc=0.46 test_acc=0.26\n", | |
". loss=0.796143370787702 train acc=0.52 test_acc=0.27\n", | |
". loss=0.612050470774999 train acc=0.59 test_acc=0.29\n", | |
". loss=0.4773917717006552 train acc=0.65 test_acc=0.31\n", | |
". loss=0.37463887846847105 train acc=0.69 test_acc=0.31\n", | |
". loss=0.29680369993601546 train acc=0.74 test_acc=0.31\n", | |
". loss=0.23645351396738554 train acc=0.79 test_acc=0.31\n", | |
". loss=0.1889699959953699 train acc=0.83 test_acc=0.32\n", | |
". loss=0.15135058949980768 train acc=0.86 test_acc=0.32\n", | |
". loss=0.12157437481055114 train acc=0.90 test_acc=0.33\n", | |
". loss=0.09787664604424022 train acc=0.92 test_acc=0.33\n", | |
". loss=0.0791015285658215 train acc=0.94 test_acc=0.33\n", | |
". loss=0.06403825974546547 train acc=0.96 test_acc=0.34\n", | |
". loss=0.051886278683072204 train acc=0.98 test_acc=0.34\n", | |
". loss=0.04212034399920099 train acc=0.98 test_acc=0.34\n", | |
". loss=0.03422095929256212 train acc=0.99 test_acc=0.34\n", | |
". loss=0.027750474473344975 train acc=0.99 test_acc=0.35\n", | |
". loss=0.022496616959196115 train acc=1.00 test_acc=0.35\n", | |
". loss=0.0182753664050702 train acc=1.00 test_acc=0.35\n", | |
". loss=0.014869399895866994 train acc=1.00 test_acc=0.35\n", | |
". loss=0.012095314696915702 train acc=1.00 test_acc=0.35\n", | |
". loss=0.00982814442875535 train acc=1.00 test_acc=0.36\n", | |
". loss=0.007967124522815386 train acc=1.00 test_acc=0.36\n", | |
". loss=0.006439375318507309 train acc=1.00 test_acc=0.36\n", | |
". loss=0.005197087487190476 train acc=1.00 test_acc=0.36\n", | |
". loss=0.004188851208431322 train acc=1.00 test_acc=0.37\n", | |
". loss=0.003372376354147401 train acc=1.00 test_acc=0.36\n", | |
". loss=0.0027101630654043007 train acc=1.00 test_acc=0.36\n", | |
". loss=0.002173077352041258 train acc=1.00 test_acc=0.36\n", | |
". loss=0.0017359028478273414 train acc=1.00 test_acc=0.36\n", | |
". loss=0.0013822257157413485 train acc=1.00 test_acc=0.36\n", | |
". loss=0.0010974183963595831 train acc=1.00 test_acc=0.36\n", | |
". loss=0.0008689256716532196 train acc=1.00 test_acc=0.36\n", | |
". loss=0.0006857459089074381 train acc=1.00 test_acc=0.36\n", | |
". loss=0.0005391318658554301 train acc=1.00 test_acc=0.36\n", | |
". loss=0.0004227504209984751 train acc=1.00 test_acc=0.36\n", | |
". loss=0.00033068982348024944 train acc=1.00 test_acc=0.36\n", | |
". loss=0.00025798704142677634 train acc=1.00 test_acc=0.36\n", | |
". loss=0.0002007413460167255 train acc=1.00 test_acc=0.36\n", | |
". loss=0.0001558219465871158 train acc=1.00 test_acc=0.36\n", | |
". loss=0.00012062046262150352 train acc=1.00 test_acc=0.36\n", | |
". loss=9.307841869440112e-05 train acc=1.00 test_acc=0.36\n", | |
". loss=7.16323119402043e-05 train acc=1.00 test_acc=0.36\n", | |
". loss=5.496665825257861e-05 train acc=1.00 test_acc=0.36\n", | |
". loss=4.204840809181516e-05 train acc=1.00 test_acc=0.36\n" | |
] | |
} | |
], | |
"source": [ | |
"# Train for 500 Adam steps; report train/test accuracy every 10 iterations.\n", | |
"print(\"Acc: {:.2f}\".format(score(nn, X_tr, y_tr)))\n", | |
"for it in range(500):\n", | |
"    loss = nn.train(X_tr, y_tr.astype(np.int32))\n", | |
"    if it % 10 == 0:\n", | |
"        train_acc = score(nn, X_tr, y_tr)\n", | |
"        test_acc = score(nn, X_te, y_te)\n", | |
"        print(\". loss={} train acc={:.2f} test_acc={:.2f}\".format(loss, train_acc, test_acc))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 0. , 0. , 0. , 0. , 0. ,\n", | |
" 0. , 1. , 0. , 0. , 0. ],\n", | |
" [ 0. , 1. , 0. , 0. , 0. ,\n", | |
" 0. , 0. , 0. , 0. , 0. ],\n", | |
" [ 0. , 0. , 0. , 0. , 0. ,\n", | |
" 0. , 0. , 0. , 1. , 0. ],\n", | |
" [ 0.74288784, 0.02723951, 0. , 0. , 0. ,\n", | |
" 0. , 0. , 0.22987264, 0. , 0. ],\n", | |
" [ 1. , 0. , 0. , 0. , 0. ,\n", | |
" 0. , 0. , 0. , 0. , 0. ]])" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Sparsemax predictions are sparse probability vectors (many exact zeros).\n", | |
"nn.predict(X_te)[:5]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Acc: 0.08\n", | |
". loss=9.01976599621005 train acc=0.09 test_acc=0.09\n", | |
". loss=5.970537785723535 train acc=0.12 test_acc=0.10\n", | |
". loss=4.027242153532298 train acc=0.19 test_acc=0.15\n", | |
". loss=2.906558256956219 train acc=0.28 test_acc=0.20\n", | |
". loss=2.254216397641683 train acc=0.37 test_acc=0.23\n", | |
". loss=1.819587603325597 train acc=0.45 test_acc=0.27\n", | |
". loss=1.5255773090072133 train acc=0.51 test_acc=0.28\n", | |
". loss=1.311435933493566 train acc=0.58 test_acc=0.30\n", | |
". loss=1.1463621981329986 train acc=0.63 test_acc=0.31\n", | |
". loss=1.0153323742794962 train acc=0.68 test_acc=0.32\n", | |
". loss=0.9075062673000046 train acc=0.72 test_acc=0.32\n", | |
". loss=0.81651455111882 train acc=0.76 test_acc=0.32\n", | |
". loss=0.738993701613929 train acc=0.79 test_acc=0.32\n", | |
". loss=0.6722895362147565 train acc=0.81 test_acc=0.34\n", | |
". loss=0.6136221147742348 train acc=0.84 test_acc=0.35\n", | |
". loss=0.5612616713704905 train acc=0.87 test_acc=0.35\n", | |
". loss=0.5141108998971561 train acc=0.89 test_acc=0.36\n", | |
". loss=0.47157964777228034 train acc=0.91 test_acc=0.35\n", | |
". loss=0.4331703927163727 train acc=0.92 test_acc=0.36\n", | |
". loss=0.3984767365904459 train acc=0.93 test_acc=0.36\n", | |
". loss=0.36698307785564616 train acc=0.95 test_acc=0.36\n", | |
". loss=0.3382974979091719 train acc=0.96 test_acc=0.36\n", | |
". loss=0.31224350181849037 train acc=0.97 test_acc=0.36\n", | |
". loss=0.2882961926872035 train acc=0.98 test_acc=0.36\n", | |
". loss=0.26634790829855504 train acc=0.98 test_acc=0.37\n", | |
". loss=0.2463701948240226 train acc=0.99 test_acc=0.37\n", | |
". loss=0.22817642998624407 train acc=0.99 test_acc=0.37\n", | |
". loss=0.2116415905962718 train acc=0.99 test_acc=0.36\n", | |
". loss=0.1965810596863039 train acc=1.00 test_acc=0.36\n", | |
". loss=0.18289383501221998 train acc=1.00 test_acc=0.37\n", | |
". loss=0.17041848231462334 train acc=1.00 test_acc=0.37\n", | |
". loss=0.15906432026566028 train acc=1.00 test_acc=0.37\n", | |
". loss=0.14870840315894923 train acc=1.00 test_acc=0.37\n", | |
". loss=0.13925842538511468 train acc=1.00 test_acc=0.37\n", | |
". loss=0.13059900020695592 train acc=1.00 test_acc=0.37\n", | |
". loss=0.12262367480258937 train acc=1.00 test_acc=0.37\n", | |
". loss=0.11529267988767013 train acc=1.00 test_acc=0.37\n", | |
". loss=0.10856877382024797 train acc=1.00 test_acc=0.38\n", | |
". loss=0.10238162342965251 train acc=1.00 test_acc=0.38\n", | |
". loss=0.09667579303228187 train acc=1.00 test_acc=0.38\n", | |
". loss=0.09142046588410166 train acc=1.00 test_acc=0.38\n", | |
". loss=0.0865707045380719 train acc=1.00 test_acc=0.38\n", | |
". loss=0.08208864900432233 train acc=1.00 test_acc=0.38\n", | |
". loss=0.0779232608443935 train acc=1.00 test_acc=0.38\n", | |
". loss=0.07404319582630571 train acc=1.00 test_acc=0.38\n", | |
". loss=0.07042847961433552 train acc=1.00 test_acc=0.38\n", | |
". loss=0.06706243787911093 train acc=1.00 test_acc=0.38\n", | |
". loss=0.06392477402539103 train acc=1.00 test_acc=0.38\n", | |
". loss=0.060990767865520726 train acc=1.00 test_acc=0.38\n", | |
". loss=0.058246144271951114 train acc=1.00 test_acc=0.38\n" | |
] | |
} | |
], | |
"source": [ | |
"# Baseline: the same architecture trained with softmax cross-entropy.\n", | |
"nn = OneLayerNN(n_features=X.shape[1], n_hidden=50, n_classes=10, loss=\"softmax\")\n", | |
"print(\"Acc: {:.2f}\".format(score(nn, X_tr, y_tr)))\n", | |
"for it in range(500):\n", | |
"    loss = nn.train(X_tr, y_tr.astype(np.int32))\n", | |
"    if it % 10 == 0:\n", | |
"        train_acc = score(nn, X_tr, y_tr)\n", | |
"        test_acc = score(nn, X_te, y_te)\n", | |
"        print(\". loss={} train acc={:.2f} test_acc={:.2f}\".format(loss, train_acc, test_acc))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 1.39861496e-05, 4.02147580e-02, 2.25333152e-02,\n", | |
" 3.22134513e-04, 2.24215233e-03, 2.84148616e-04,\n", | |
" 9.34122035e-01, 2.12915096e-06, 1.69795156e-04,\n", | |
" 9.55454203e-05],\n", | |
" [ 9.52711463e-04, 8.92738578e-04, 2.19926016e-03,\n", | |
" 3.35625624e-03, 7.82014583e-06, 4.78311306e-02,\n", | |
" 3.73392212e-01, 6.60634240e-05, 5.63419290e-01,\n", | |
" 7.88251708e-03],\n", | |
" [ 2.80045843e-03, 3.71939822e-05, 2.95855777e-02,\n", | |
" 1.35453110e-04, 5.95135374e-02, 8.03435162e-01,\n", | |
" 8.27088982e-03, 1.77216817e-04, 7.82432772e-02,\n", | |
" 1.78012337e-02],\n", | |
" [ 1.18992021e-02, 2.56404028e-01, 3.30392047e-03,\n", | |
" 8.34558962e-05, 7.38034134e-04, 6.55556707e-03,\n", | |
" 2.32805030e-03, 5.48619526e-01, 1.69603365e-01,\n", | |
" 4.64851566e-04],\n", | |
" [ 3.30875378e-01, 9.13542683e-03, 1.19233632e-02,\n", | |
" 5.81140531e-01, 1.57772466e-02, 7.44224337e-06,\n", | |
" 3.69221688e-02, 1.28359040e-02, 8.04384261e-07,\n", | |
" 1.38173508e-03]])" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Softmax predictions are dense: every class receives nonzero probability.\n", | |
"nn.predict(X_te)[:5]" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.4.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.