Skip to content

Instantly share code, notes, and snippets.

@alexlenail
Last active January 22, 2019 16:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alexlenail/7a31a0e5e8bb64256b59493d5add5e85 to your computer and use it in GitHub Desktop.
Save alexlenail/7a31a0e5e8bb64256b59493d5add5e85 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test SPAMS.fistaGraph\n",
"\n",
"Docs http://spams-devel.gforge.inria.fr/doc-R/html/doc_spams006.html#sec39"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import numpy as np\n",
"import pandas as pd\n",
"import networkx as nx\n",
"import scipy\n",
"import scipy.sparse as ssp\n",
"\n",
"import time\n",
"import spams\n",
"\n",
"def flatten(list_of_lists): return [item for sublist in list_of_lists for item in sublist]\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## I. Generate synthetic \"easy\" graph and graph signals"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"num_nodes = 15\n",
"p_edge = 0.3\n",
"g1 = nx.fast_gnp_random_graph(num_nodes, p_edge)\n",
"g2 = nx.fast_gnp_random_graph(num_nodes, p_edge)\n",
"\n",
"nx.relabel_nodes(g2, {number: number + 15 for number in g2.nodes}, copy=False)\n",
"\n",
"g = nx.compose(g1, g2)\n",
"\n",
"g.add_edge(1, 25)\n",
"g.add_edge(2, 20)\n",
"g.add_edge(3, 15)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/alex/miniconda3/lib/python3.7/site-packages/networkx/drawing/nx_pylab.py:611: MatplotlibDeprecationWarning: isinstance(..., numbers.Number)\n",
" if cb.is_numlike(alpha):\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"nx.draw_spring(g)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"edgelist = nx.to_pandas_edgelist(g).values"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### True signal comes from 0 and its neighbors"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0, 2, 11, 12, 14]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"signal_nodes = [0]+[n for n in g.neighbors(0)]\n",
"signal_nodes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### First half is class 1, second half is class 0"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0.])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"NUM_POSITIVES = 20\n",
"NUM_NEGATIVES = 20\n",
"y = np.concatenate((np.ones(NUM_POSITIVES), np.zeros(NUM_NEGATIVES)))\n",
"y"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" <th>6</th>\n",
" <th>7</th>\n",
" <th>8</th>\n",
" <th>9</th>\n",
" <th>...</th>\n",
" <th>20</th>\n",
" <th>21</th>\n",
" <th>22</th>\n",
" <th>23</th>\n",
" <th>24</th>\n",
" <th>25</th>\n",
" <th>26</th>\n",
" <th>27</th>\n",
" <th>28</th>\n",
" <th>29</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1.0</td>\n",
" <td>-1.626836</td>\n",
" <td>1.0</td>\n",
" <td>-1.667516</td>\n",
" <td>-0.108261</td>\n",
" <td>0.390096</td>\n",
" <td>-0.861851</td>\n",
" <td>-0.156386</td>\n",
" <td>1.577740</td>\n",
" <td>1.355505</td>\n",
" <td>...</td>\n",
" <td>0.494552</td>\n",
" <td>0.357644</td>\n",
" <td>-0.999935</td>\n",
" <td>0.391030</td>\n",
" <td>-0.382606</td>\n",
" <td>1.747134</td>\n",
" <td>-0.019657</td>\n",
" <td>2.119312</td>\n",
" <td>0.032667</td>\n",
" <td>-0.146469</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1.0</td>\n",
" <td>0.362366</td>\n",
" <td>1.0</td>\n",
" <td>-0.635774</td>\n",
" <td>-0.836827</td>\n",
" <td>-0.276659</td>\n",
" <td>-0.809957</td>\n",
" <td>0.002832</td>\n",
" <td>0.183393</td>\n",
" <td>1.224518</td>\n",
" <td>...</td>\n",
" <td>0.983904</td>\n",
" <td>0.622035</td>\n",
" <td>-0.619685</td>\n",
" <td>0.268317</td>\n",
" <td>-1.240670</td>\n",
" <td>-1.471808</td>\n",
" <td>0.637272</td>\n",
" <td>0.492749</td>\n",
" <td>0.539597</td>\n",
" <td>-0.780348</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1.0</td>\n",
" <td>0.126977</td>\n",
" <td>1.0</td>\n",
" <td>-0.112256</td>\n",
" <td>1.780466</td>\n",
" <td>0.342744</td>\n",
" <td>-0.631568</td>\n",
" <td>-1.072373</td>\n",
" <td>-2.127742</td>\n",
" <td>3.137458</td>\n",
" <td>...</td>\n",
" <td>0.010640</td>\n",
" <td>0.457412</td>\n",
" <td>-0.346469</td>\n",
" <td>0.837266</td>\n",
" <td>0.216169</td>\n",
" <td>-1.255527</td>\n",
" <td>2.048838</td>\n",
" <td>-1.903665</td>\n",
" <td>1.030514</td>\n",
" <td>1.808216</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1.0</td>\n",
" <td>-1.144054</td>\n",
" <td>1.0</td>\n",
" <td>-0.673975</td>\n",
" <td>-0.488836</td>\n",
" <td>-0.597826</td>\n",
" <td>-0.467641</td>\n",
" <td>-0.079044</td>\n",
" <td>-0.620556</td>\n",
" <td>-0.000212</td>\n",
" <td>...</td>\n",
" <td>-1.718976</td>\n",
" <td>-1.400485</td>\n",
" <td>0.870043</td>\n",
" <td>-0.137372</td>\n",
" <td>-0.695848</td>\n",
" <td>-0.011468</td>\n",
" <td>-0.990155</td>\n",
" <td>0.890150</td>\n",
" <td>-0.059695</td>\n",
" <td>2.337334</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1.0</td>\n",
" <td>-1.383579</td>\n",
" <td>1.0</td>\n",
" <td>-0.772876</td>\n",
" <td>-0.020679</td>\n",
" <td>1.385638</td>\n",
" <td>-0.299431</td>\n",
" <td>-0.074756</td>\n",
" <td>-0.465050</td>\n",
" <td>0.641146</td>\n",
" <td>...</td>\n",
" <td>-0.083365</td>\n",
" <td>0.821267</td>\n",
" <td>0.327011</td>\n",
" <td>1.879729</td>\n",
" <td>0.859166</td>\n",
" <td>-1.927820</td>\n",
" <td>-0.632995</td>\n",
" <td>-0.974812</td>\n",
" <td>1.312887</td>\n",
" <td>0.643763</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 30 columns</p>\n",
"</div>"
],
"text/plain": [
" 0 1 2 3 4 5 6 7 \\\n",
"0 1.0 -1.626836 1.0 -1.667516 -0.108261 0.390096 -0.861851 -0.156386 \n",
"1 1.0 0.362366 1.0 -0.635774 -0.836827 -0.276659 -0.809957 0.002832 \n",
"2 1.0 0.126977 1.0 -0.112256 1.780466 0.342744 -0.631568 -1.072373 \n",
"3 1.0 -1.144054 1.0 -0.673975 -0.488836 -0.597826 -0.467641 -0.079044 \n",
"4 1.0 -1.383579 1.0 -0.772876 -0.020679 1.385638 -0.299431 -0.074756 \n",
"\n",
" 8 9 ... 20 21 22 23 \\\n",
"0 1.577740 1.355505 ... 0.494552 0.357644 -0.999935 0.391030 \n",
"1 0.183393 1.224518 ... 0.983904 0.622035 -0.619685 0.268317 \n",
"2 -2.127742 3.137458 ... 0.010640 0.457412 -0.346469 0.837266 \n",
"3 -0.620556 -0.000212 ... -1.718976 -1.400485 0.870043 -0.137372 \n",
"4 -0.465050 0.641146 ... -0.083365 0.821267 0.327011 1.879729 \n",
"\n",
" 24 25 26 27 28 29 \n",
"0 -0.382606 1.747134 -0.019657 2.119312 0.032667 -0.146469 \n",
"1 -1.240670 -1.471808 0.637272 0.492749 0.539597 -0.780348 \n",
"2 0.216169 -1.255527 2.048838 -1.903665 1.030514 1.808216 \n",
"3 -0.695848 -0.011468 -0.990155 0.890150 -0.059695 2.337334 \n",
"4 0.859166 -1.927820 -0.632995 -0.974812 1.312887 0.643763 \n",
"\n",
"[5 rows x 30 columns]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = pd.DataFrame(np.random.normal(size=(NUM_POSITIVES+NUM_NEGATIVES,30)))\n",
"data.loc[0:NUM_POSITIVES, signal_nodes] = 1\n",
"data.loc[NUM_POSITIVES:NUM_POSITIVES+NUM_NEGATIVES, signal_nodes] = -1\n",
"\n",
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" <th>6</th>\n",
" <th>7</th>\n",
" <th>8</th>\n",
" <th>9</th>\n",
" <th>...</th>\n",
" <th>20</th>\n",
" <th>21</th>\n",
" <th>22</th>\n",
" <th>23</th>\n",
" <th>24</th>\n",
" <th>25</th>\n",
" <th>26</th>\n",
" <th>27</th>\n",
" <th>28</th>\n",
" <th>29</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>35</th>\n",
" <td>-1.0</td>\n",
" <td>0.711782</td>\n",
" <td>-1.0</td>\n",
" <td>-0.300740</td>\n",
" <td>0.771417</td>\n",
" <td>0.419459</td>\n",
" <td>0.134785</td>\n",
" <td>-1.810631</td>\n",
" <td>1.698634</td>\n",
" <td>0.089898</td>\n",
" <td>...</td>\n",
" <td>0.161639</td>\n",
" <td>0.903178</td>\n",
" <td>-0.773467</td>\n",
" <td>0.145566</td>\n",
" <td>0.344990</td>\n",
" <td>0.548841</td>\n",
" <td>0.444323</td>\n",
" <td>0.594573</td>\n",
" <td>0.014439</td>\n",
" <td>0.111183</td>\n",
" </tr>\n",
" <tr>\n",
" <th>36</th>\n",
" <td>-1.0</td>\n",
" <td>-0.363750</td>\n",
" <td>-1.0</td>\n",
" <td>0.319550</td>\n",
" <td>0.582871</td>\n",
" <td>-0.366024</td>\n",
" <td>2.009703</td>\n",
" <td>-0.849986</td>\n",
" <td>0.252170</td>\n",
" <td>-0.349139</td>\n",
" <td>...</td>\n",
" <td>-1.003231</td>\n",
" <td>0.109738</td>\n",
" <td>-0.425068</td>\n",
" <td>-2.085400</td>\n",
" <td>-0.354644</td>\n",
" <td>-0.923705</td>\n",
" <td>1.521942</td>\n",
" <td>-0.776429</td>\n",
" <td>0.755865</td>\n",
" <td>-1.172837</td>\n",
" </tr>\n",
" <tr>\n",
" <th>37</th>\n",
" <td>-1.0</td>\n",
" <td>-0.212570</td>\n",
" <td>-1.0</td>\n",
" <td>0.809884</td>\n",
" <td>-1.436927</td>\n",
" <td>0.728994</td>\n",
" <td>-0.141139</td>\n",
" <td>-0.658784</td>\n",
" <td>0.074903</td>\n",
" <td>1.463954</td>\n",
" <td>...</td>\n",
" <td>-0.643148</td>\n",
" <td>-0.576813</td>\n",
" <td>-1.599052</td>\n",
" <td>1.233322</td>\n",
" <td>-1.907276</td>\n",
" <td>1.503759</td>\n",
" <td>0.632590</td>\n",
" <td>-0.616824</td>\n",
" <td>1.147559</td>\n",
" <td>1.531823</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38</th>\n",
" <td>-1.0</td>\n",
" <td>-0.561549</td>\n",
" <td>-1.0</td>\n",
" <td>0.625856</td>\n",
" <td>-0.822934</td>\n",
" <td>-0.979575</td>\n",
" <td>0.199540</td>\n",
" <td>-1.276706</td>\n",
" <td>0.575109</td>\n",
" <td>-0.314934</td>\n",
" <td>...</td>\n",
" <td>-1.487596</td>\n",
" <td>0.377102</td>\n",
" <td>1.008312</td>\n",
" <td>0.622991</td>\n",
" <td>0.523273</td>\n",
" <td>-1.566366</td>\n",
" <td>0.030276</td>\n",
" <td>0.092158</td>\n",
" <td>1.426438</td>\n",
" <td>1.271735</td>\n",
" </tr>\n",
" <tr>\n",
" <th>39</th>\n",
" <td>-1.0</td>\n",
" <td>-1.453606</td>\n",
" <td>-1.0</td>\n",
" <td>1.025963</td>\n",
" <td>-0.199703</td>\n",
" <td>0.196850</td>\n",
" <td>-0.033403</td>\n",
" <td>-0.131592</td>\n",
" <td>0.952464</td>\n",
" <td>-1.219901</td>\n",
" <td>...</td>\n",
" <td>-0.476985</td>\n",
" <td>1.392822</td>\n",
" <td>-1.000271</td>\n",
" <td>-0.647108</td>\n",
" <td>0.612353</td>\n",
" <td>1.975893</td>\n",
" <td>-0.744764</td>\n",
" <td>-1.440025</td>\n",
" <td>-0.248124</td>\n",
" <td>1.060998</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 30 columns</p>\n",
"</div>"
],
"text/plain": [
" 0 1 2 3 4 5 6 7 \\\n",
"35 -1.0 0.711782 -1.0 -0.300740 0.771417 0.419459 0.134785 -1.810631 \n",
"36 -1.0 -0.363750 -1.0 0.319550 0.582871 -0.366024 2.009703 -0.849986 \n",
"37 -1.0 -0.212570 -1.0 0.809884 -1.436927 0.728994 -0.141139 -0.658784 \n",
"38 -1.0 -0.561549 -1.0 0.625856 -0.822934 -0.979575 0.199540 -1.276706 \n",
"39 -1.0 -1.453606 -1.0 1.025963 -0.199703 0.196850 -0.033403 -0.131592 \n",
"\n",
" 8 9 ... 20 21 22 23 \\\n",
"35 1.698634 0.089898 ... 0.161639 0.903178 -0.773467 0.145566 \n",
"36 0.252170 -0.349139 ... -1.003231 0.109738 -0.425068 -2.085400 \n",
"37 0.074903 1.463954 ... -0.643148 -0.576813 -1.599052 1.233322 \n",
"38 0.575109 -0.314934 ... -1.487596 0.377102 1.008312 0.622991 \n",
"39 0.952464 -1.219901 ... -0.476985 1.392822 -1.000271 -0.647108 \n",
"\n",
" 24 25 26 27 28 29 \n",
"35 0.344990 0.548841 0.444323 0.594573 0.014439 0.111183 \n",
"36 -0.354644 -0.923705 1.521942 -0.776429 0.755865 -1.172837 \n",
"37 -1.907276 1.503759 0.632590 -0.616824 1.147559 1.531823 \n",
"38 0.523273 -1.566366 0.030276 0.092158 1.426438 1.271735 \n",
"39 0.612353 1.975893 -0.744764 -1.440025 -0.248124 1.060998 \n",
"\n",
"[5 rows x 30 columns]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.tail()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"edge_weights = [1]*len(edgelist)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Groups are assigned as the one-hop neighborhood of every node"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"30"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"neighborhoods = [[i]+[n for n in g.neighbors(node)] for i, node in enumerate(g.nodes)]\n",
"num_groups = len(neighborhoods)\n",
"num_groups"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## II. Test SPAMS.fistaGraph on \"easy\" graph"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Name: spams.fistaGraph\n",
"#\n",
"# Description:\n",
"# spams.fistaGraph solves sparse regularized problems.\n",
"# X is a design matrix of size m x p\n",
"# X=[x^1,...,x^n]', where the x_i's are the rows of X\n",
"# Y=[y^1,...,y^n] is a matrix of size m x n\n",
"# \n",
"# It implements the algorithms FISTA, ISTA and subgradient descent for solving\n",
"# \n",
"# min_W loss(W) + lambda1 psi(W)\n",
"# \n",
"# The function psi are those used by spams.proximalGraph (see documentation)\n",
"# for the loss functions, see the documentation of spams.fistaFlat\n",
"# \n",
"# This function can also handle intercepts (last row of W is not regularized),\n",
"# and/or non-negativity constraints on W.\n",
"#"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# graph: struct\n",
"# with three fields, eta_g, groups, and groups_var\n",
"# \n",
"# The first fields sets the weights for every group\n",
"# graph.eta_g double N vector \n",
"\n",
"eta_g = np.ones(num_groups)\n",
" \n",
"# The next field sets inclusion relations between groups (but not between groups and variables):\n",
"# graph.groups sparse (double or boolean) N x N matrix \n",
"# the (i,j) entry is non-zero if and only if i is different than j and \n",
"# gi is included in gj.\n",
"\n",
"groups = scipy.sparse.csc_matrix(np.zeros((num_groups,num_groups)),dtype=np.bool)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"i, j = zip(*flatten([[(i, j) for j in neighbors] for i, neighbors in enumerate(neighborhoods)]))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# The next field sets inclusion relations between groups and variables\n",
"# graph.groups_var sparse (double or boolean) p x N matrix\n",
"# the (i,j) entry is non-zero if and only if the variable i is included \n",
"# in gj, but not in any children of gj.\n",
"\n",
"# scipy.sparse.csc_matrix((data, (row_ind, col_ind)), [shape=(M, N)])\n",
"# where data, row_ind and col_ind satisfy the relationship a[row_ind[k], col_ind[k]] = data[k].\n",
"\n",
"groups_var = scipy.sparse.csc_matrix((np.ones(len(i)),(i,j)),dtype=np.bool)\n",
"\n",
"# graph: struct\n",
"# with three fields, eta_g, groups, and groups_var\n",
"# \n",
"graph = {'eta_g':eta_g,'groups':groups,'groups_var':groups_var}"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"graph['eta_g']"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<30x30 sparse matrix of type '<class 'numpy.bool_'>'\n",
"\twith 0 stored elements in Compressed Sparse Column format>"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"graph['groups']"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<30x30 sparse matrix of type '<class 'numpy.bool_'>'\n",
"\twith 152 stored elements in Compressed Sparse Column format>"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"graph['groups_var']"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Usage: spams.fistaGraph( Y,\n",
"# X,\n",
"# W0,\n",
"# graph,\n",
"# return_optim_info=False,\n",
"# numThreads=-1,\n",
"# max_it=1000,\n",
"# L0=1.0,\n",
"# fixed_step=False,\n",
"# gamma=1.5,\n",
"# lambda1=1.0,\n",
"# lambda2=0.,\n",
"# lambda3=0.,\n",
"# a=1.0,\n",
"# b=0.,\n",
"# tol=0.000001,\n",
"# it0=100,\n",
"# compute_gram=False,\n",
"# intercept=False,\n",
"# regul=\"\",\n",
"# loss=\"\",\n",
"# verbose=False,\n",
"# pos=False,\n",
"# ista=False,\n",
"# subgrad=False,\n",
"# linesearch_mode=0) \n",
"#\n",
"# Inputs:\n",
"# Y : double dense m x n matrix\n",
"\n",
"Y = np.asfortranarray(np.expand_dims(y, axis=1)).astype(float)\n",
"Y = spams.normalize(Y)\n",
"\n",
"# X : double dense or sparse m x p matrix\n",
"\n",
"X = np.asfortranarray(data.values).astype(float)\n",
"X = spams.normalize(X)\n",
"\n",
"# W0 : double dense p x n matrix or p x Nn matrix for multi-logistic loss initial guess\n",
"\n",
"W0 = np.zeros((X.shape[1],Y.shape[1]),dtype=np.float64,order=\"F\")\n",
"\n",
"# graph : struct see documentation of proximalGraph\n",
"# return_optim_info : if true the function will return a tuple of matrices.\n",
"# loss : choice of loss, see above\n",
"# regul : choice of regularization, see below\n",
"# lambda1 : regularization parameter\n",
"# lambda2 : regularization parameter, 0 by default\n",
"# lambda3 : regularization parameter, 0 by default\n",
"# verbose : verbosity level, false by default\n",
"# pos : adds positivity constraints on the coefficients, false by default\n",
"# numThreads : number of threads for exploiting multi-core / multi-cpus. By default, it takes the value -1, which automatically selects all the available CPUs/cores.\n",
"# max_it : maximum number of iterations, 100 by default\n",
"# it0 : frequency for computing duality gap, every 10 iterations by default\n",
"# tol : tolerance for stopping criteration, which is a relative duality gap if it is available, or a relative change of parameters.\n",
"# gamma : multiplier for increasing the parameter L in fista, 1.5 by default\n",
"# L0 : initial parameter L in fista, 0.1 by default, should be small enough\n",
"# fixed_step : deactive the line search for L in fista and use L0 instead\n",
"# compute_gram : pre-compute X^TX, false by default.\n",
"# intercept : do not regularize last row of W, false by default.\n",
"# ista : use ista instead of fista, false by default.\n",
"# subgrad : if not ista, use subradient descent instead of fista, false by default.\n",
"# a :\n",
"# b : if subgrad, the gradient step is a/(t+b) also similar options as proximalTree\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# Regularizers: \n",
"# Given an input matrix U=[u^1,\\ldots,u^n], it computes a matrix V=[v^1,\\ldots,v^n] such that\n",
"#\n",
"# if one chooses a regularization functions on vectors, it computes for each column u of U, a column v of V solving\n",
"# if regul='l0' argmin 0.5||u-v||_2^2 + lambda1||v||_0\n",
"# if regul='l1' argmin 0.5||u-v||_2^2 + lambda1||v||_1\n",
"# if regul='l2' argmin 0.5||u-v||_2^2 + 0.5lambda1||v||_2^2\n",
"# if regul='elastic-net' argmin 0.5||u-v||_2^2 + lambda1||v||_1 + lambda1_2||v||_2^2\n",
"# if regul='fused-lasso' argmin 0.5||u-v||_2^2 + lambda1 FL(v) + lambda1_2||v||_1 + lambda1_3||v||_2^2\n",
"# if regul='linf' argmin 0.5||u-v||_2^2 + lambda1||v||_inf\n",
"# if regul='l1-constraint' argmin 0.5||u-v||_2^2 s.t. ||v||_1 <= lambda1\n",
"# if regul='l2-not-squared' argmin 0.5||u-v||_2^2 + lambda1||v||_2\n",
"# if regul='group-lasso-l2' argmin 0.5||u-v||_2^2 + lambda1 sum_g ||v_g||_2 where the groups are either defined by groups or by size_group,\n",
"# if regul='group-lasso-linf' argmin 0.5||u-v||_2^2 + lambda1 sum_g ||v_g||_inf\n",
"# if regul='sparse-group-lasso-l2' argmin 0.5||u-v||_2^2 + lambda1 sum_g ||v_g||_2 + lambda1_2 ||v||_1 where the groups are either defined by groups or by size_group,\n",
"# if regul='sparse-group-lasso-linf' argmin 0.5||u-v||_2^2 + lambda1 sum_g ||v_g||_inf + lambda1_2 ||v||_1\n",
"# if regul='trace-norm-vec' argmin 0.5||u-v||_2^2 + lambda1 ||mat(v)||_* where mat(v) has size_group rows\n",
"#\n",
"# if regul='graph' argmin 0.5||u-v||_2^2 + lambda1\\sum_{g \\in G} \\eta_g||v_g||_inf\n",
"# if regul='graph+ridge' argmin 0.5||u-v||_2^2 + lambda1\\sum_{g \\in G} \\eta_g||v_g||_inf + lambda1_2||v||_2^2\n",
"#\n",
"# if one chooses a regularization function on matrices\n",
"# if regul='l1l2', V= argmin 0.5||U-V||_F^2 + lambda1||V||_{1/2}\n",
"# if regul='l1linf', V= argmin 0.5||U-V||_F^2 + lambda1||V||_{1/inf}\n",
"# if regul='l1l2+l1', V= argmin 0.5||U-V||_F^2 + lambda1||V||_{1/2} + lambda1_2||V||_{1/1}\n",
"# if regul='l1linf+l1', V= argmin 0.5||U-V||_F^2 + lambda1||V||_{1/inf} + lambda1_2||V||_{1/1}\n",
"# if regul='l1linf+row-column', V= argmin 0.5||U-V||_F^2 + lambda1||V||_{1/inf} + lambda1_2||V'||_{1/inf}\n",
"# if regul='trace-norm', V= argmin 0.5||U-V||_F^2 + lambda1||V||_*\n",
"# if regul='rank', V= argmin 0.5||U-V||_F^2 + lambda1 rank(V)\n",
"# if regul='none', V= argmin 0.5||U-V||_F^2\n",
"#\n",
"# if regul='multi-task-graph' V=argmin 0.5||U-V||_F^2 + lambda1 \\sum_{i=1}^n\\sum_{g \\in G} \\eta_g||v^i_g||_inf + lambda1_2 \\sum_{g \\in G} \\eta_g max_{j in g}||V_j||_{inf}\n",
"#\n",
"# for all these regularizations, it is possible to enforce non-negativity constraints\n",
"# with the option pos, and to prevent the last row of U to be regularized, with\n",
"# the option intercept\n",
"\n",
"# Note:\n",
"# Valid values for the regularization parameter (regul) for fistaGraph (beyond those listed above) are:\n",
"# \"tree-l0\"\n",
"# \"tree-l2\"\n",
"# \"tree-linf\"\n",
"# \"graph-l2\",\n",
"# \"multi-task-tree\"\n",
"# \"rank-vec\"\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# Loss: \n",
"# - if loss='square' and regul is a regularization function for vectors,\n",
"# the entries of Y are real-valued, W = [w^1,...,w^n] is a matrix of size p x n\n",
"# For all column y of Y, it computes a column w of W such that\n",
"# w = argmin 0.5||y- X w||_2^2 + lambda1 psi(w)\n",
"# \n",
"# - if loss='square' and regul is a regularization function for matrices\n",
"# the entries of Y are real-valued, W is a matrix of size p x n. \n",
"# It computes the matrix W such that\n",
"# W = argmin 0.5||Y- X W||_F^2 + lambda1 psi(W)\n",
"# \n",
"# - loss='square-missing' same as loss='square', but handles missing data\n",
"# represented by NaN (not a number) in the matrix Y\n",
"# \n",
"# - if loss='logistic' and regul is a regularization function for vectors,\n",
"# the entries of Y are either -1 or +1, W = [w^1,...,w^n] is a matrix of size p x n\n",
"# For all column y of Y, it computes a column w of W such that\n",
"# w = argmin (1/m)sum_{j=1}^m log(1+e^(-y_j x^j' w)) + lambda1 psi(w),\n",
"# where x^j is the j-th row of X.\n",
"# \n",
"# - if loss='logistic' and regul is a regularization function for matrices\n",
"# the entries of Y are either -1 or +1, W is a matrix of size p x n\n",
"# W = argmin sum_{i=1}^n(1/m)sum_{j=1}^m log(1+e^(-y^i_j x^j' w^i)) + lambda1 psi(W)\n",
"# \n",
"# - if loss='multi-logistic' and regul is a regularization function for vectors,\n",
"# the entries of Y are in {0,1,...,N} where N is the total number of classes\n",
"# W = [W^1,...,W^n] is a matrix of size p x Nn, each submatrix W^i is of size p x N\n",
"# for all submatrix WW of W, and column y of Y, it computes\n",
"# WW = argmin (1/m)sum_{j=1}^m log(sum_{j=1}^r e^(x^j'(ww^j-ww^{y_j}))) + lambda1 sum_{j=1}^N psi(ww^j),\n",
"# where ww^j is the j-th column of WW.\n",
"# \n",
"# - if loss='multi-logistic' and regul is a regularization function for matrices,\n",
"# the entries of Y are in {0,1,...,N} where N is the total number of classes\n",
"# W is a matrix of size p x N, it computes\n",
"# W = argmin (1/m)sum_{j=1}^m log(sum_{j=1}^r e^(x^j'(w^j-w^{y_j}))) + lambda1 psi(W)\n",
"# where ww^j is the j-th column of WW.\n",
"# \n",
"# - loss='cur' useful to perform sparse CUR matrix decompositions, \n",
"# W = argmin 0.5||Y-X*W*X||_F^2 + lambda1 psi(W)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean loss: 0.094094, mean relative duality_gap: -1.000000, time: 0.002162, number of iterations: 482.000000\n"
]
}
],
"source": [
"verbose = True\n",
"lambda1 = 0 # regularization term (no regularization)\n",
"max_it = 100 # maximum number of iterations\n",
"L0 = 0.1\n",
"tol = 1e-5\n",
"intercept = False\n",
"pos = False\n",
"compute_gram = True\n",
"\n",
"loss = 'square'\n",
"regul = 'none'\n",
"tic = time.time()\n",
"\n",
"(W, optim_info) = spams.fistaGraph(Y, X, W0, graph, return_optim_info=True, loss=loss, regul=regul, verbose=verbose)\n",
"\n",
"tac = time.time()\n",
"t = tac - tic\n",
"\n",
"print('mean loss: %f, mean relative duality_gap: %f, time: %f, number of iterations: %f' %(np.mean(optim_info[0,:]),np.mean(optim_info[2,:]),t,np.mean(optim_info[3,:])))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean loss: 0.500000, mean relative duality_gap: 0.000000, time: 0.003627, number of iterations: 100.000000\n"
]
}
],
"source": [
"verbose = True\n",
"lambda1 = 0 # regularization term (no regularization)\n",
"max_it = 100 # maximum number of iterations\n",
"L0 = 0.1\n",
"tol = 1e-5\n",
"intercept = False\n",
"pos = False\n",
"compute_gram = True\n",
"\n",
"loss = 'square'\n",
"regul = 'graph'\n",
"tic = time.time()\n",
"\n",
"(W, optim_info) = spams.fistaGraph(Y, X, W0, graph, return_optim_info=True, loss=loss, regul=regul, verbose=verbose)\n",
"\n",
"tac = time.time()\n",
"t = tac - tic\n",
"\n",
"print('mean loss: %f, mean relative duality_gap: %f, time: %f, number of iterations: %f' %(np.mean(optim_info[0,:]),np.mean(optim_info[2,:]),t,np.mean(optim_info[3,:])))"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# Output:\n",
"# W: double dense p x n matrix or p x Nn matrix (for multi-logistic loss)\n",
"# optim: optional, double dense 4 x n matrix.\n",
"# first row: values of the objective functions.\n",
"# third row: values of the relative duality gap (if available)\n",
"# fourth row: number of iterations\n",
"# optim_info: vector of size 4, containing information of the optimization.\n",
"# W = spams.fistaGraph(Y,X,W0,graph,return_optim_info = False,...)\n",
"# (W,optim_info) = spams.fistaGraph(Y,X,W0,graph,return_optim_info = True,...)\n",
"#"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0, 2, 11, 12, 14]"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"signal_nodes"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" <th>6</th>\n",
" <th>7</th>\n",
" <th>8</th>\n",
" <th>9</th>\n",
" <th>...</th>\n",
" <th>20</th>\n",
" <th>21</th>\n",
" <th>22</th>\n",
" <th>23</th>\n",
" <th>24</th>\n",
" <th>25</th>\n",
" <th>26</th>\n",
" <th>27</th>\n",
" <th>28</th>\n",
" <th>29</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.0</td>\n",
" <td>-0.0</td>\n",
" <td>0.0</td>\n",
" <td>-0.0</td>\n",
" <td>-0.0</td>\n",
" <td>0.0</td>\n",
" <td>-0.0</td>\n",
" <td>-0.0</td>\n",
" <td>-0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>-0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>-0.0</td>\n",
" <td>-0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1 rows × 30 columns</p>\n",
"</div>"
],
"text/plain": [
" 0 1 2 3 4 5 6 7 8 9 ... 20 21 22 23 \\\n",
"0 0.0 -0.0 0.0 -0.0 -0.0 0.0 -0.0 -0.0 -0.0 0.0 ... 0.0 -0.0 0.0 0.0 \n",
"\n",
" 24 25 26 27 28 29 \n",
"0 -0.0 -0.0 0.0 0.0 0.0 0.0 \n",
"\n",
"[1 rows x 30 columns]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"W = pd.DataFrame(W)\n",
"W.T"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 False\n",
"dtype: bool"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"W[W != 0].any()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment