sparsemax loss for Theano
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Author: Vlad Niculae <vlad@vene.ro>\n",
"# License: matching sparsemax_theano.py from https://github.com/Unbabel/sparsemax\n",
"\n",
"import numpy as np\n",
"import theano.tensor as T\n",
"import theano as tn\n",
"from theano import gof\n",
"\n",
"import sparsemax_theano"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"class SparsemaxLossGrad(gof.Op):\n",
"\n",
"    __props__ = ()\n",
"\n",
"    itypes = [T.dmatrix, T.ivector]\n",
"    otypes = [T.dmatrix]\n",
"\n",
"    def perform(self, node, input_storage, output_storage):\n",
"        sm, y = input_storage\n",
"        # gradient of the sparsemax loss w.r.t. the scores:\n",
"        # sparsemax(z) minus the one-hot encoding of the gold label\n",
"        g = sm.copy()  # avoid mutating the input buffer in place\n",
"        g[np.arange(g.shape[0]), y] -= 1\n",
"        output_storage[0][0] = g\n",
"\n",
"    def grad(self, *args):\n",
"        raise NotImplementedError()\n",
"\n",
"    def infer_shape(self, node, shape):\n",
"        return [shape[0]]"
]
},
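{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is a minimal NumPy sketch of the simplex projection that `sparsemax_theano.project_onto_simplex` is assumed to implement (following Martins & Astudillo, 2016): sort the scores, find the support size, compute the threshold $\\tau$, and clip. The helper name `sparsemax_np` is introduced here for illustration only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def sparsemax_np(z):\n",
"    \"\"\"Project a 1-d score vector z onto the probability simplex.\"\"\"\n",
"    z_sorted = np.sort(z)[::-1]\n",
"    cumsum = np.cumsum(z_sorted)\n",
"    k = np.arange(1, len(z) + 1)\n",
"    # support size: largest k such that 1 + k * z_(k) > sum of top k scores\n",
"    k_z = k[1 + k * z_sorted > cumsum][-1]\n",
"    tau = (cumsum[k_z - 1] - 1) / k_z\n",
"    return np.maximum(z - tau, 0)\n",
"\n",
"sparsemax_np(np.array([3.0, 1.0, 2.0]))  # expected: array([ 1.,  0.,  0.])"
]
},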
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sparsemax_loss_grad = SparsemaxLossGrad()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class SparsemaxLoss(gof.Op):\n",
"\n",
"    __props__ = ()\n",
"    itypes = [T.dmatrix, T.ivector]\n",
"    otypes = [T.dvector]\n",
"\n",
"    def perform(self, node, input_storage, output_storage):\n",
"        x, y = input_storage\n",
"\n",
"        loss = -x[np.arange(x.shape[0]), y]\n",
"\n",
"        for i in range(x.shape[0]):\n",
"            sm, tau, _ = sparsemax_theano.project_onto_simplex(x[i])\n",
"            loss[i] += 0.5 * np.sum(x[i][np.abs(sm) > 1e-12] ** 2 - tau ** 2)\n",
"\n",
"        loss += 0.5\n",
"        output_storage[0][0] = loss\n",
"\n",
"    def grad(self, inp, grads):\n",
"        x, y = inp\n",
"        g_loss = grads[0]  # incoming gradient w.r.t. each per-sample loss\n",
"        sm = sparsemax_theano.sparsemax(x)\n",
"        g = sparsemax_loss_grad(sm, y)\n",
"        # chain rule: scale each row by the incoming gradient;\n",
"        # the integer labels get a zero gradient\n",
"        return [g_loss.dimshuffle(0, 'x') * g,\n",
"                T.zeros_like(y, dtype=tn.config.floatX)]\n",
"\n",
"    def R_op(self, inputs, eval_points):\n",
"        if None in eval_points:\n",
"            return [None]\n",
"        return self.grad(inputs, eval_points)\n",
"\n",
"    def infer_shape(self, node, shapes):\n",
"        x_shape, y_shape = shapes\n",
"        return [(x_shape[0],)]"
]
},
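{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check (an added sketch, not part of the original gist): the analytic gradient of the sparsemax loss w.r.t. the scores is $\\mathrm{sparsemax}(\\mathbf{z}) - \\mathbf{e}_y$, so a central finite difference on a NumPy version of the loss should match it. The hypothetical `sparsemax_loss_np` below mirrors the computation in `SparsemaxLoss.perform` and reuses the `sparsemax_np` sketch above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def sparsemax_loss_np(z, y):\n",
"    p = sparsemax_np(z)\n",
"    support = p > 1e-12\n",
"    # on the support, p_j = z_j - tau, so tau can be read off any coordinate\n",
"    tau = (z[support] - p[support])[0]\n",
"    return -z[y] + 0.5 * np.sum(z[support] ** 2 - tau ** 2) + 0.5\n",
"\n",
"z = np.array([3.0, 1.0, 2.0])\n",
"y = 2\n",
"eps = 1e-6\n",
"num_grad = np.array([(sparsemax_loss_np(z + eps * e, y)\n",
"                      - sparsemax_loss_np(z - eps * e, y)) / (2 * eps)\n",
"                     for e in np.eye(len(z))])\n",
"print(num_grad)                        # numeric estimate\n",
"print(sparsemax_np(z) - np.eye(3)[y])  # analytic: sparsemax(z) - onehot(y)"
]
},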
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sparsemax_loss = SparsemaxLoss()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 1.  0.  0.]\n",
"[[ 0.33333333  0.         -0.33333333]\n",
" [ 0.          0.          0.        ]\n",
" [ 0.          0.          0.        ]]\n"
]
}
],
"source": [
"Z = np.array([[3, 1, 2], [1, 3, 2], [1, 3, 2]], dtype=np.double)\n",
"ytr = np.array([2, 1, 1], dtype=np.int32)\n",
"\n",
"Z_tn = T.dmatrix()\n",
"y_tn = T.ivector()\n",
"\n",
"loss = sparsemax_loss(Z_tn, y_tn)\n",
"print(loss.eval({Z_tn: Z, y_tn: ytr}))\n",
"\n",
"\n",
"h = T.grad(cost=loss.mean(), wrt=Z_tn)\n",
"\n",
"print(h.eval({Z_tn: Z, y_tn: ytr}))"
]
},
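{
"cell_type": "markdown",
"metadata": {},
"source": [
"These values match the closed form of the sparsemax loss (Martins & Astudillo, 2016) that `perform` implements,\n",
"\n",
"$$L(\\mathbf{z}, y) = -z_y + \\tfrac{1}{2} \\sum_{j \\in S(\\mathbf{z})} \\big(z_j^2 - \\tau(\\mathbf{z})^2\\big) + \\tfrac{1}{2},$$\n",
"\n",
"where $S(\\mathbf{z})$ is the support of $\\mathrm{sparsemax}(\\mathbf{z})$ and $\\tau$ the threshold. For the first row, $\\mathbf{z} = (3, 1, 2)$ and $y = 2$: the support is $\\{0\\}$ with $\\tau = 2$, so $L = -2 + \\tfrac{1}{2}(9 - 4) + \\tfrac{1}{2} = 1$. The gradient is $\\mathrm{sparsemax}(\\mathbf{z}) - \\mathbf{e}_y$, which averaged over the three rows gives the matrix above."
]
},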
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.datasets import make_classification\n",
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"X, y = make_classification(n_samples=5000, n_features=300, n_informative=100,\n",
"                           n_classes=10, random_state=0)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.1, random_state=0)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from lasagne.layers import InputLayer, DenseLayer, get_output, get_all_params\n",
"from lasagne.nonlinearities import softmax, rectify\n",
"from lasagne.objectives import categorical_crossentropy\n",
"from lasagne.updates import adam\n",
"\n",
"class OneLayerNN(object):\n",
"    def __init__(self, n_features, n_hidden, n_classes, loss='softmax'):\n",
"\n",
"        X = InputLayer(shape=(None, n_features))\n",
"        y = T.ivector()\n",
"\n",
"        hidden = DenseLayer(X, n_hidden, nonlinearity=rectify)\n",
"        # raw scores; the link function is applied by the chosen loss\n",
"        final = DenseLayer(hidden, n_classes, nonlinearity=None)\n",
"\n",
"        out = get_output(final)\n",
"        if loss == 'softmax':\n",
"            y_pred = softmax(out)\n",
"            loss = categorical_crossentropy(y_pred, y).mean()\n",
"        elif loss == 'sparsemax':\n",
"            y_pred = sparsemax_theano.sparsemax(out)\n",
"            loss = sparsemax_loss(out, y).mean()\n",
"        else:\n",
"            raise ValueError(\"loss must be 'softmax' or 'sparsemax'\")\n",
"\n",
"        self.train = tn.function([X.input_var, y],\n",
"                                 loss,\n",
"                                 updates=adam(loss, get_all_params(final)))\n",
"\n",
"        self.predict = tn.function([X.input_var], y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def score(nn, X, y):\n",
"    return np.mean(np.argmax(nn.predict(X), axis=1) == y)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"nn = OneLayerNN(n_features=X.shape[1], n_hidden=50, n_classes=10, loss=\"sparsemax\")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Acc: 0.11\n",
". loss=7.123902198078001 train acc=0.12 test_acc=0.11\n",
". loss=4.507310067407898 train acc=0.17 test_acc=0.14\n",
". loss=2.9644636319351263 train acc=0.24 test_acc=0.17\n",
". loss=2.019472765280979 train acc=0.31 test_acc=0.23\n",
". loss=1.4372722258354609 train acc=0.38 test_acc=0.24\n",
". loss=1.0558284703777132 train acc=0.46 test_acc=0.26\n",
". loss=0.796143370787702 train acc=0.52 test_acc=0.27\n",
". loss=0.612050470774999 train acc=0.59 test_acc=0.29\n",
". loss=0.4773917717006552 train acc=0.65 test_acc=0.31\n",
". loss=0.37463887846847105 train acc=0.69 test_acc=0.31\n",
". loss=0.29680369993601546 train acc=0.74 test_acc=0.31\n",
". loss=0.23645351396738554 train acc=0.79 test_acc=0.31\n",
". loss=0.1889699959953699 train acc=0.83 test_acc=0.32\n",
". loss=0.15135058949980768 train acc=0.86 test_acc=0.32\n",
". loss=0.12157437481055114 train acc=0.90 test_acc=0.33\n",
". loss=0.09787664604424022 train acc=0.92 test_acc=0.33\n",
". loss=0.0791015285658215 train acc=0.94 test_acc=0.33\n",
". loss=0.06403825974546547 train acc=0.96 test_acc=0.34\n",
". loss=0.051886278683072204 train acc=0.98 test_acc=0.34\n",
". loss=0.04212034399920099 train acc=0.98 test_acc=0.34\n",
". loss=0.03422095929256212 train acc=0.99 test_acc=0.34\n",
". loss=0.027750474473344975 train acc=0.99 test_acc=0.35\n",
". loss=0.022496616959196115 train acc=1.00 test_acc=0.35\n",
". loss=0.0182753664050702 train acc=1.00 test_acc=0.35\n",
". loss=0.014869399895866994 train acc=1.00 test_acc=0.35\n",
". loss=0.012095314696915702 train acc=1.00 test_acc=0.35\n",
". loss=0.00982814442875535 train acc=1.00 test_acc=0.36\n",
". loss=0.007967124522815386 train acc=1.00 test_acc=0.36\n",
". loss=0.006439375318507309 train acc=1.00 test_acc=0.36\n",
". loss=0.005197087487190476 train acc=1.00 test_acc=0.36\n",
". loss=0.004188851208431322 train acc=1.00 test_acc=0.37\n",
". loss=0.003372376354147401 train acc=1.00 test_acc=0.36\n",
". loss=0.0027101630654043007 train acc=1.00 test_acc=0.36\n",
". loss=0.002173077352041258 train acc=1.00 test_acc=0.36\n",
". loss=0.0017359028478273414 train acc=1.00 test_acc=0.36\n",
". loss=0.0013822257157413485 train acc=1.00 test_acc=0.36\n",
". loss=0.0010974183963595831 train acc=1.00 test_acc=0.36\n",
". loss=0.0008689256716532196 train acc=1.00 test_acc=0.36\n",
". loss=0.0006857459089074381 train acc=1.00 test_acc=0.36\n",
". loss=0.0005391318658554301 train acc=1.00 test_acc=0.36\n",
". loss=0.0004227504209984751 train acc=1.00 test_acc=0.36\n",
". loss=0.00033068982348024944 train acc=1.00 test_acc=0.36\n",
". loss=0.00025798704142677634 train acc=1.00 test_acc=0.36\n",
". loss=0.0002007413460167255 train acc=1.00 test_acc=0.36\n",
". loss=0.0001558219465871158 train acc=1.00 test_acc=0.36\n",
". loss=0.00012062046262150352 train acc=1.00 test_acc=0.36\n",
". loss=9.307841869440112e-05 train acc=1.00 test_acc=0.36\n",
". loss=7.16323119402043e-05 train acc=1.00 test_acc=0.36\n",
". loss=5.496665825257861e-05 train acc=1.00 test_acc=0.36\n",
". loss=4.204840809181516e-05 train acc=1.00 test_acc=0.36\n"
]
}
],
"source": [
"print(\"Acc: {:.2f}\".format(score(nn, X_tr, y_tr)))\n",
"for it in range(500):\n",
"    loss = nn.train(X_tr, y_tr.astype(np.int32))\n",
"    if it % 10 == 0:\n",
"        train_acc = score(nn, X_tr, y_tr)\n",
"        test_acc = score(nn, X_te, y_te)\n",
"        print(\". loss={} train acc={:.2f} test_acc={:.2f}\".format(loss, train_acc, test_acc))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 1. , 0. , 0. , 0. ],\n",
" [ 0. , 1. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ],\n",
" [ 0. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 1. , 0. ],\n",
" [ 0.74288784, 0.02723951, 0. , 0. , 0. ,\n",
" 0. , 0. , 0.22987264, 0. , 0. ],\n",
" [ 1. , 0. , 0. , 0. , 0. ,\n",
" 0. , 0. , 0. , 0. , 0. ]])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nn.predict(X_te)[:5]"
]
},
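{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small added check to quantify the sparsity visible above: count the classes that receive nonzero probability per test example. Sparsemax outputs exact zeros, so this is typically far below the 10 classes, whereas the softmax model trained next assigns strictly positive probability to every class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"proba = nn.predict(X_te)\n",
"print(\"mean nonzero classes per example: {:.2f}\".format(\n",
"    np.mean(np.sum(proba > 0, axis=1))))"
]
},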
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Acc: 0.08\n",
". loss=9.01976599621005 train acc=0.09 test_acc=0.09\n",
". loss=5.970537785723535 train acc=0.12 test_acc=0.10\n",
". loss=4.027242153532298 train acc=0.19 test_acc=0.15\n",
". loss=2.906558256956219 train acc=0.28 test_acc=0.20\n",
". loss=2.254216397641683 train acc=0.37 test_acc=0.23\n",
". loss=1.819587603325597 train acc=0.45 test_acc=0.27\n",
". loss=1.5255773090072133 train acc=0.51 test_acc=0.28\n",
". loss=1.311435933493566 train acc=0.58 test_acc=0.30\n",
". loss=1.1463621981329986 train acc=0.63 test_acc=0.31\n",
". loss=1.0153323742794962 train acc=0.68 test_acc=0.32\n",
". loss=0.9075062673000046 train acc=0.72 test_acc=0.32\n",
". loss=0.81651455111882 train acc=0.76 test_acc=0.32\n",
". loss=0.738993701613929 train acc=0.79 test_acc=0.32\n",
". loss=0.6722895362147565 train acc=0.81 test_acc=0.34\n",
". loss=0.6136221147742348 train acc=0.84 test_acc=0.35\n",
". loss=0.5612616713704905 train acc=0.87 test_acc=0.35\n",
". loss=0.5141108998971561 train acc=0.89 test_acc=0.36\n",
". loss=0.47157964777228034 train acc=0.91 test_acc=0.35\n",
". loss=0.4331703927163727 train acc=0.92 test_acc=0.36\n",
". loss=0.3984767365904459 train acc=0.93 test_acc=0.36\n",
". loss=0.36698307785564616 train acc=0.95 test_acc=0.36\n",
". loss=0.3382974979091719 train acc=0.96 test_acc=0.36\n",
". loss=0.31224350181849037 train acc=0.97 test_acc=0.36\n",
". loss=0.2882961926872035 train acc=0.98 test_acc=0.36\n",
". loss=0.26634790829855504 train acc=0.98 test_acc=0.37\n",
". loss=0.2463701948240226 train acc=0.99 test_acc=0.37\n",
". loss=0.22817642998624407 train acc=0.99 test_acc=0.37\n",
". loss=0.2116415905962718 train acc=0.99 test_acc=0.36\n",
". loss=0.1965810596863039 train acc=1.00 test_acc=0.36\n",
". loss=0.18289383501221998 train acc=1.00 test_acc=0.37\n",
". loss=0.17041848231462334 train acc=1.00 test_acc=0.37\n",
". loss=0.15906432026566028 train acc=1.00 test_acc=0.37\n",
". loss=0.14870840315894923 train acc=1.00 test_acc=0.37\n",
". loss=0.13925842538511468 train acc=1.00 test_acc=0.37\n",
". loss=0.13059900020695592 train acc=1.00 test_acc=0.37\n",
". loss=0.12262367480258937 train acc=1.00 test_acc=0.37\n",
". loss=0.11529267988767013 train acc=1.00 test_acc=0.37\n",
". loss=0.10856877382024797 train acc=1.00 test_acc=0.38\n",
". loss=0.10238162342965251 train acc=1.00 test_acc=0.38\n",
". loss=0.09667579303228187 train acc=1.00 test_acc=0.38\n",
". loss=0.09142046588410166 train acc=1.00 test_acc=0.38\n",
". loss=0.0865707045380719 train acc=1.00 test_acc=0.38\n",
". loss=0.08208864900432233 train acc=1.00 test_acc=0.38\n",
". loss=0.0779232608443935 train acc=1.00 test_acc=0.38\n",
". loss=0.07404319582630571 train acc=1.00 test_acc=0.38\n",
". loss=0.07042847961433552 train acc=1.00 test_acc=0.38\n",
". loss=0.06706243787911093 train acc=1.00 test_acc=0.38\n",
". loss=0.06392477402539103 train acc=1.00 test_acc=0.38\n",
". loss=0.060990767865520726 train acc=1.00 test_acc=0.38\n",
". loss=0.058246144271951114 train acc=1.00 test_acc=0.38\n"
]
}
],
"source": [
"nn = OneLayerNN(n_features=X.shape[1], n_hidden=50, n_classes=10, loss=\"softmax\")\n",
"print(\"Acc: {:.2f}\".format(score(nn, X_tr, y_tr)))\n",
"for it in range(500):\n",
"    loss = nn.train(X_tr, y_tr.astype(np.int32))\n",
"    if it % 10 == 0:\n",
"        train_acc = score(nn, X_tr, y_tr)\n",
"        test_acc = score(nn, X_te, y_te)\n",
"        print(\". loss={} train acc={:.2f} test_acc={:.2f}\".format(loss, train_acc, test_acc))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 1.39861496e-05, 4.02147580e-02, 2.25333152e-02,\n",
" 3.22134513e-04, 2.24215233e-03, 2.84148616e-04,\n",
" 9.34122035e-01, 2.12915096e-06, 1.69795156e-04,\n",
" 9.55454203e-05],\n",
" [ 9.52711463e-04, 8.92738578e-04, 2.19926016e-03,\n",
" 3.35625624e-03, 7.82014583e-06, 4.78311306e-02,\n",
" 3.73392212e-01, 6.60634240e-05, 5.63419290e-01,\n",
" 7.88251708e-03],\n",
" [ 2.80045843e-03, 3.71939822e-05, 2.95855777e-02,\n",
" 1.35453110e-04, 5.95135374e-02, 8.03435162e-01,\n",
" 8.27088982e-03, 1.77216817e-04, 7.82432772e-02,\n",
" 1.78012337e-02],\n",
" [ 1.18992021e-02, 2.56404028e-01, 3.30392047e-03,\n",
" 8.34558962e-05, 7.38034134e-04, 6.55556707e-03,\n",
" 2.32805030e-03, 5.48619526e-01, 1.69603365e-01,\n",
" 4.64851566e-04],\n",
" [ 3.30875378e-01, 9.13542683e-03, 1.19233632e-02,\n",
" 5.81140531e-01, 1.57772466e-02, 7.44224337e-06,\n",
" 3.69221688e-02, 1.28359040e-02, 8.04384261e-07,\n",
" 1.38173508e-03]])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nn.predict(X_te)[:5]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.4"
}
},
"nbformat": 4,
"nbformat_minor": 0
}