@nirshlezinger1
Created May 4, 2022 07:58
LISTA vs ISTA.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "LISTA vs ISTA.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyO+zRSJOsEYHpiY0D6Q/Xzw",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/nirshlezinger1/7cf8528f22709fb44fc5e13731758c48/lista-vs-ista.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Deep Unfolding\n",
"In this notebook we compare the iterative soft-thresholding algorithm (ISTA) with its deep unfolded version, coined learned ISTA (LISTA). LISTA, proposed by Gregor and LeCun in 2010, is in fact the origin of the deep unfolding methodology and has spurred a multitude of variants over the years. This notebook compares a basic implementation of LISTA to the model-based ISTA, examining the ability of deep unfolding to increase the convergence rate of iterative optimizers.\n"
],
"metadata": {
"id": "jqHEmDXWutMj"
}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import torch.utils.data as Data\n",
"import torch.nn.functional as F\n",
"import torch.nn as nn\n",
"from scipy.linalg import eigvalsh \n",
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"torch.set_default_dtype(torch.float64)"
],
"metadata": {
"id": "qgGnY-8TvTo9"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## ISTA\n",
"We begin by recalling the formulation of ISTA. The original algorithm aims at solving the LASSO problem\n",
"\\begin{equation} \n",
"\t \\hat{\\boldsymbol{s}} = \\mathop{\\arg\\min}\\limits_{\\boldsymbol{s}} \\frac{1}{2}\\|\\boldsymbol{x}-\\boldsymbol{H}\\boldsymbol{s}\\|^2 +\\rho\\|\\boldsymbol{s}\\|_1,\n",
"\\end{equation}\n",
"via the iterative update equations\n",
"\\begin{equation}\n",
" \\boldsymbol{s}^{(k+1)} \\leftarrow \\mathcal{T}_{\\beta=\\mu\\rho}\\left( \\boldsymbol{s}^{(k)} - \\mu \\boldsymbol{H}^T(\\boldsymbol{H}\\boldsymbol{s}^{(k)}-\\boldsymbol{x}) \\right), \n",
"\\end{equation}\n",
"with $\\mathcal{T}_{\\beta}$ being the entry-wise soft-thresholding operator with threshold $\\beta$.\n",
"\n",
"Since one can prove that ISTA converges when the step size $\\mu$ satisfies $\\mu \\leq \\frac{1}{\\max {\\rm eig}(\\boldsymbol{H}^T\\boldsymbol{H})}$, we use the largest admissible value as our default setting, i.e., $\\mu = 1/L$ with $L=\\max {\\rm eig}(\\boldsymbol{H}^T\\boldsymbol{H})$."
],
"metadata": {
"id": "dKT5TkAgo5W_"
}
},
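{
"cell_type": "markdown",
"source": [
"As a quick sanity check, the next cell writes the soft-thresholding operator out explicitly, applied entry-wise,\n",
"\\begin{equation}\n",
"\\mathcal{T}_{\\beta}(u) = {\\rm sign}(u)\\max(|u|-\\beta, 0),\n",
"\\end{equation}\n",
"and compares it with PyTorch's built-in `softshrink`, which the `ista` function below uses through `torch.nn.Softshrink`. The helper name `soft_threshold`, the threshold $\\beta=0.5$, and the test grid are illustrative assumptions, not part of the original notebook."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"def soft_threshold(u, beta):\n",
"    # Hypothetical helper: entry-wise T_beta(u) = sign(u) * max(|u| - beta, 0)\n",
"    return torch.sign(u) * torch.clamp(u.abs() - beta, min=0.0)\n",
"\n",
"u = torch.linspace(-2.0, 2.0, steps=9, dtype=torch.float64)\n",
"beta = 0.5\n",
"print(soft_threshold(u, beta))\n",
"# Should agree with PyTorch's built-in soft-shrinkage\n",
"print(torch.allclose(soft_threshold(u, beta), F.softshrink(u, lambd=beta)))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},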
{
"cell_type": "code",
"source": [
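"def ista(x, H, rho=0.5, L=1, max_itr=300):\n",
"    # Start from the all-zeros estimate\n",
"    s = torch.zeros(H.shape[1])\n",
"    # Soft-thresholding with threshold mu * rho, where mu = 1 / L\n",
"    proj = torch.nn.Softshrink(lambd=rho / L)\n",
"    for _ in range(max_itr):\n",
"        # Gradient step on 0.5 * ||x - H s||^2 with step size 1 / L\n",
"        s_tild = s - 1 / L * (H.T @ (H @ s - x))\n",
"        # Proximal (shrinkage) step\n",
"        s = proj(s_tild)\n",
"    return s"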
],
"metadata": {
"id": "ejv6JLOf8Rkj"
},
"execution_count": 2,
"outputs": []
},
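{
"cell_type": "markdown",
"source": [
"Because ISTA with step size $\\mu=1/L$ is a proximal-gradient (majorization-minimization) method, the LASSO objective should not increase from one iteration to the next. The standalone cell below checks this on a small random problem. The sizes, $\\rho=0.1$, and the seed are arbitrary assumptions, a local random generator is used so the notebook's global seed is left untouched, and the loop re-implements the same update as `ista` rather than calling it, so that the objective can be recorded after every iteration."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"g = torch.Generator().manual_seed(0)  # local RNG, leaves the global seed untouched\n",
"n_small, m_small, rho_small = 30, 50, 0.1\n",
"H_small = torch.randn(n_small, m_small, generator=g, dtype=torch.float64)\n",
"x_small = torch.randn(n_small, generator=g, dtype=torch.float64)\n",
"# Largest eigenvalue of H^T H (eigvalsh returns eigenvalues in ascending order)\n",
"L_small = torch.linalg.eigvalsh(H_small.T @ H_small)[-1].item()\n",
"\n",
"def lasso_objective(s):\n",
"    # 0.5 * ||x - H s||^2 + rho * ||s||_1\n",
"    return 0.5 * torch.sum((x_small - H_small @ s) ** 2).item() + rho_small * s.abs().sum().item()\n",
"\n",
"s_small = torch.zeros(m_small, dtype=torch.float64)\n",
"objective = [lasso_objective(s_small)]\n",
"for _ in range(100):\n",
"    # Gradient step with step size 1 / L, then soft-thresholding at rho / L\n",
"    s_small = F.softshrink(s_small - (H_small.T @ (H_small @ s_small - x_small)) / L_small, lambd=rho_small / L_small)\n",
"    objective.append(lasso_objective(s_small))\n",
"\n",
"# Largest change between consecutive iterations (expected <= 0 up to round-off)\n",
"print(max(b - a for a, b in zip(objective, objective[1:])))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},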
{
"cell_type": "markdown",
"source": [
"## Data\n",
"Next, we generate a data set $\\mathcal{D}=\\{(\\boldsymbol{s}_t, \\boldsymbol{x}_t)\\}_{t=1}^{n_t}$, where each $\\boldsymbol{s}_t$ has $m=200$ entries, out of which at most $k=4$ are non-zero, while $\\boldsymbol{x}_t$ has $n=150$ entries and is obtained via \n",
"\\begin{equation}\n",
"\\boldsymbol{x}_t = \\boldsymbol{H}\\boldsymbol{s}_t + \\boldsymbol{w}_t,\n",
"\\end{equation}\n",
"with $\\boldsymbol{w}_t$ being i.i.d. Gaussian noise.\n",
"\n",
"To that end, we define a dedicated class inheriting from `Dataset`, so that each sample is a tuple comprising $\\boldsymbol{x}_t$, $\\boldsymbol{H}$, and $\\boldsymbol{s}_t$."
],
"metadata": {
"id": "pb3vwVLThFnF"
}
},
{
"cell_type": "code",
"source": [
"class SimulatedData(Data.Dataset):\n",
"    def __init__(self, x, H, s):\n",
"        # Columns of x (n x N) and s (m x N) are individual samples\n",
"        self.x = x\n",
"        self.s = s\n",
"        self.H = H\n",
"\n",
"    def __len__(self):\n",
"        return self.x.shape[1]\n",
"\n",
"    def __getitem__(self, idx):\n",
"        # Each sample is the tuple (x_t, H, s_t)\n",
"        x = self.x[:, idx]\n",
"        H = self.H\n",
"        s = self.s[:, idx]\n",
"        return x, H, s\n"
],
"metadata": {
"id": "lFnR46I-zJN2"
},
"execution_count": 3,
"outputs": []
},
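{
"cell_type": "markdown",
"source": [
"Since `SimulatedData.__getitem__` returns $\\boldsymbol{H}$ with every sample, PyTorch's default collate function stacks one copy of $\\boldsymbol{H}$ per sample, so each batch carries a tensor of shape `(batch_size, n, m)`. The standalone cell below illustrates this with hypothetical toy sizes, using a minimal copy of the dataset class (`TinyData`, an illustrative stand-in) so it runs on its own. With the sizes used later in this notebook ($n=150$, $m=200$, batches of up to $512$ samples in double precision), that stacked tensor takes roughly 123 MB per batch, even though the training loop never uses it."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import torch.utils.data as Data\n",
"\n",
"class TinyData(Data.Dataset):\n",
"    # Hypothetical stand-in mirroring SimulatedData: columns of x and s are samples\n",
"    def __init__(self, x, H, s):\n",
"        self.x, self.H, self.s = x, H, s\n",
"\n",
"    def __len__(self):\n",
"        return self.x.shape[1]\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return self.x[:, idx], self.H, self.s[:, idx]\n",
"\n",
"g = torch.Generator().manual_seed(0)  # local RNG, leaves the global seed untouched\n",
"n_toy, m_toy, N_toy = 3, 5, 8  # hypothetical toy sizes\n",
"H_toy = torch.randn(n_toy, m_toy, generator=g)\n",
"toy = TinyData(torch.randn(n_toy, N_toy, generator=g), H_toy, torch.randn(m_toy, N_toy, generator=g))\n",
"loader = Data.DataLoader(toy, batch_size=4, shuffle=False, generator=g)\n",
"\n",
"b_x, b_H, b_s = next(iter(loader))\n",
"# H is stacked once per sample: shapes (4, 3), (4, 3, 5), (4, 5)\n",
"print(b_x.shape, b_H.shape, b_s.shape)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},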
{
"cell_type": "markdown",
"source": [
"The following is a helper function that converts a target signal-to-noise ratio (SNR, in dB) into the standard deviation of the additive noise."
],
"metadata": {
"id": "TSLfjEPy2WmE"
}
},
{
"cell_type": "code",
"source": [
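"def snr2std(y, snr):\n",
"    # Signal power in dB\n",
"    s_db = 10 * torch.log10(y.var())\n",
"    # Noise power in dB for the requested SNR (in dB)\n",
"    noise_db = s_db - snr\n",
"    # Convert noise power (dB) to standard deviation\n",
"    noise_std = 10 ** (noise_db / 20)\n",
"    return noise_std"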
],
"metadata": {
"id": "E83tiuJc7OGg"
},
"execution_count": 4,
"outputs": []
},
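{
"cell_type": "markdown",
"source": [
"To see why `snr2std` divides by $20$: for a target SNR (in dB), the noise power is $P_s\\cdot 10^{-{\\rm SNR}/10}$, with $P_s$ the empirical variance of the signal, and the standard deviation is its square root, $10^{(10\\log_{10}P_s - {\\rm SNR})/20}$. The standalone cell below (which re-declares the same function) draws noise for a target SNR of $30$ dB and measures the empirical SNR. The signal length and seed are arbitrary assumptions, so the estimate should be close to, but not exactly, $30$ dB."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"\n",
"def snr2std(y, snr):\n",
"    # Same conversion as the notebook: signal power (dB) minus the target SNR (dB)\n",
"    s_db = 10 * torch.log10(y.var())\n",
"    noise_db = s_db - snr\n",
"    return 10 ** (noise_db / 20)\n",
"\n",
"g = torch.Generator().manual_seed(0)  # local RNG, leaves the global seed untouched\n",
"y_demo = torch.randn(100_000, generator=g, dtype=torch.float64)\n",
"noise_demo = snr2std(y_demo, 30) * torch.randn(100_000, generator=g, dtype=torch.float64)\n",
"# Empirical SNR = 10 log10(signal power / noise power)\n",
"empirical_snr = 10 * torch.log10(y_demo.var() / noise_demo.var())\n",
"print(f'empirical SNR: {empirical_snr.item():.2f} dB')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},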
{
"cell_type": "markdown",
"source": [
"The next function creates a data set. To keep things simple, we use as labels the output of the model-based ISTA with $300$ iterations (the default `max_itr` of `ista`), rather than the actual sparse signal (which may be difficult to retrieve in some scenarios)."
],
"metadata": {
"id": "HCnvoe7A2bZl"
}
},
{
"cell_type": "code",
"source": [
"def create_data_set(H, n, m, k, N=1000, batch_size=512, snr=30):\n",
"    # The maximal eigenvalue of H^T H (Lipschitz constant of the gradient)\n",
"    L = float(eigvalsh(H.t() @ H, subset_by_index=[m - 1, m - 1])[0])\n",
"    # Initialization\n",
"    x = torch.zeros(n, N)\n",
"    s = torch.zeros(m, N)\n",
"    # Create signals\n",
"    for i in range(N):\n",
"        # Random k indices and values\n",
"        s_ind = torch.randperm(m)[:k]\n",
"        s_nz = torch.randn(k)\n",
"        # Create the noiseless observation H s\n",
"        x[:, i] = H[:, s_ind] @ s_nz\n",
"        # Add noise at the requested SNR\n",
"        noise = snr2std(x[:, i], snr) * torch.randn(n)\n",
"        x[:, i] += noise\n",
"        # Label: the LASSO solution computed by ISTA (default rho and max_itr)\n",
"        s[:, i] = ista(x=x[:, i], H=H, L=L)\n",
"\n",
"    simulated = SimulatedData(x=x, H=H, s=s)\n",
"    data_loader = Data.DataLoader(dataset=simulated, batch_size=batch_size, shuffle=True)\n",
"    return data_loader"
],
"metadata": {
"id": "GhUgG1V1mE3x"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Finally, we generate the data set"
],
"metadata": {
"id": "daJJ2Jcl3YiO"
}
},
{
"cell_type": "code",
"source": [
"#n, m, k = 50, 70, 4\n",
"n, m, k = 150, 200, 4\n",
"# Measurement matrix (Gaussian, with unit-norm columns)\n",
"H = torch.randn(n, m)\n",
"H /= torch.norm(H, dim=0) \n",
"\n",
"train_loader = create_data_set(H, n=n, m=m, k=k, N=1000)\n",
"test_loader = create_data_set(H, n=n, m=m, k=k, N=1000)"
],
"metadata": {
"id": "ogzWwvvRnxPu"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Let's see what the samples look like"
],
"metadata": {
"id": "ll2C2WvcoVka"
}
},
{
"cell_type": "code",
"source": [
"x_exm, _, s_exm = test_loader.dataset[5]\n",
"plt.figure(figsize=(8, 8)) \n",
"plt.subplot(2, 1, 1) \n",
"plt.plot(x_exm, label = 'observation' ) \n",
"plt.xlabel('Index', fontsize=10)\n",
"plt.ylabel('Value', fontsize=10)\n",
"plt.legend( )\n",
"plt.subplot (2, 1, 2) \n",
"plt.plot(s_exm, label = 'sparse signal', color='k') \n",
"plt.xlabel('Index', fontsize=10)\n",
"plt.ylabel('Value', fontsize=10)\n",
"plt.legend( )\n",
"plt.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 497
},
"id": "fXnrCnVn3kbs",
"outputId": "c05a0fd9-e6f3-44e5-a460-43f0a632adde"
},
"execution_count": 7,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 576x576 with 2 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Train and Validation\n",
"\n",
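"For training, we use stochastic gradient descent with momentum (SGDM), a step learning-rate scheduler, and the squared $\\ell_2$ (MSE) loss summed over each batch. After every epoch, the model is evaluated on the held-out set."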
],
"metadata": {
"id": "RHKRaregobdz"
}
},
{
"cell_type": "code",
"source": [
"def train(model, train_loader, valid_loader, num_epochs=30):\n",
"    \"\"\"Train a network.\n",
"    Returns:\n",
"        loss_test {numpy} -- per-sample loss on the validation set, one value per epoch\n",
"    \"\"\"\n",
"    # Initialization\n",
"    optimizer = torch.optim.SGD(\n",
"        model.parameters(),\n",
"        lr=5e-05,\n",
"        momentum=0.9,\n",
"        weight_decay=0,\n",
"    )\n",
"    # Decays the learning rate by 10x every 50 epochs\n",
"    scheduler = torch.optim.lr_scheduler.StepLR(\n",
"        optimizer, step_size=50, gamma=0.1\n",
"    )\n",
"    loss_train = np.zeros((num_epochs,))\n",
"    loss_test = np.zeros((num_epochs,))\n",
"    # Main loop\n",
"    for epoch in range(num_epochs):\n",
"        model.train()\n",
"        train_loss = 0\n",
"        for step, (b_x, b_H, b_s) in enumerate(train_loader):\n",
"            # b_x, b_H, b_s = b_x.cuda(), b_H.cuda(), b_s.cuda()\n",
"            s_hat = model(b_x)\n",
"            loss = F.mse_loss(s_hat, b_s, reduction=\"sum\")\n",
"            optimizer.zero_grad()\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"            train_loss += loss.item()\n",
"        loss_train[epoch] = train_loss / len(train_loader.dataset)\n",
"        scheduler.step()\n",
"\n",
"        # Validation (no gradients needed)\n",
"        model.eval()\n",
"        test_loss = 0\n",
"        with torch.no_grad():\n",
"            for step, (b_x, b_H, b_s) in enumerate(valid_loader):\n",
"                # b_x, b_H, b_s = b_x.cuda(), b_H.cuda(), b_s.cuda()\n",
"                s_hat = model(b_x)\n",
"                test_loss += F.mse_loss(s_hat, b_s, reduction=\"sum\").item()\n",
"        loss_test[epoch] = test_loss / len(valid_loader.dataset)\n",
"        # Print\n",
"        if epoch % 10 == 0:\n",
"            print(\n",
"                \"Epoch %d, Train loss %.8f, Validation loss %.8f\"\n",
"                % (epoch, loss_train[epoch], loss_test[epoch])\n",
"            )\n",
"\n",
"    return loss_test"
],
"metadata": {
"id": "63WxCf_JITSb"
},
"execution_count": 8,
"outputs": []
},
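{
"cell_type": "markdown",
"source": [
"Note that with the default `num_epochs=30` and `StepLR(step_size=50, gamma=0.1)`, the scheduler never reaches its first decay, so training effectively runs at a constant learning rate of $5\\cdot 10^{-5}$. The standalone cell below illustrates this with a throwaway single-parameter optimizer standing in for the LISTA parameters (an assumption made purely for illustration). For the decay to take effect, `num_epochs` would need to exceed $50$, or `step_size` would need to be smaller."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"\n",
"param_demo = torch.nn.Parameter(torch.zeros(1))  # throwaway stand-in for model parameters\n",
"optimizer_demo = torch.optim.SGD([param_demo], lr=5e-05, momentum=0.9)\n",
"scheduler_demo = torch.optim.lr_scheduler.StepLR(optimizer_demo, step_size=50, gamma=0.1)\n",
"\n",
"lrs = []\n",
"for epoch in range(30):  # same as the default num_epochs of train\n",
"    lrs.append(optimizer_demo.param_groups[0]['lr'])\n",
"    optimizer_demo.step()  # param_demo has no gradient, so this leaves it unchanged\n",
"    scheduler_demo.step()\n",
"print(min(lrs), max(lrs))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},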
{
"cell_type": "markdown",
"source": [
"## Learned ISTA (LISTA)\n",
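"As noted earlier, there are many variants of unfolded ISTA. Here we use the following update equation:\n",
"\\begin{equation}\n",
" \\boldsymbol{s}^{(k+1)} \\leftarrow \\mathcal{T}_{\\beta^{(k)}}\\left( \\boldsymbol{s}^{(k)} - \\mu^{(k)} (\\boldsymbol{B}\\boldsymbol{s}^{(k)}-\\boldsymbol{A}\\boldsymbol{x}) \\right), \n",
"\\end{equation}\n",
"with $\\boldsymbol{A}, \\boldsymbol{B}$ and $\\{\\beta^{(k)},\\mu^{(k)}\\}$ being trainable parameters. Setting $\\boldsymbol{A}=\\boldsymbol{H}^T$, $\\boldsymbol{B}=\\boldsymbol{H}^T\\boldsymbol{H}$, $\\mu^{(k)}=\\mu$ and $\\beta^{(k)}=\\mu\\rho$ recovers the ISTA update; the model below initializes $\\boldsymbol{A}$ and $\\boldsymbol{B}$ this way."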
],
"metadata": {
"id": "HxrAyYc2pHbz"
}
},
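{
"cell_type": "markdown",
"source": [
"As implemented in `LISTA_Model.forward` (with the default $\\rho=1$), each layer computes $\\boldsymbol{s}\\leftarrow\\mathcal{T}_{\\beta^{(k)}}(\\boldsymbol{s}-\\mu^{(k)}(\\boldsymbol{B}\\boldsymbol{s}-\\boldsymbol{A}\\boldsymbol{x}))$. With $\\boldsymbol{A}=\\boldsymbol{H}^T$, $\\boldsymbol{B}=\\boldsymbol{H}^T\\boldsymbol{H}$, $\\mu^{(k)}=1/L$ and $\\beta^{(k)}=\\rho/L$ in every layer, $T$ unfolded layers reproduce $T$ ISTA iterations (up to round-off), and training lets LISTA depart from this ISTA-equivalent starting point. The standalone cell below checks the equivalence numerically on a small random problem; the sizes, $\\rho$, and the helper names `unfolded_step` and `ista_step` are illustrative assumptions, separate from the `LISTA_Model` class."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"g = torch.Generator().manual_seed(0)  # local RNG, leaves the global seed untouched\n",
"n_small, m_small, rho_small, T_small = 20, 40, 0.1, 10\n",
"H_small = torch.randn(n_small, m_small, generator=g, dtype=torch.float64)\n",
"x_small = torch.randn(n_small, generator=g, dtype=torch.float64)\n",
"L_small = torch.linalg.eigvalsh(H_small.T @ H_small)[-1].item()\n",
"A_tied, B_tied = H_small.T, H_small.T @ H_small  # tied weights: A = H^T, B = H^T H\n",
"mu_tied, beta_tied = 1 / L_small, rho_small / L_small  # same step size and threshold in every layer\n",
"\n",
"def unfolded_step(s):\n",
"    # Hypothetical helper: s <- T_beta(s - mu (B s - A x))\n",
"    return F.softshrink(s - mu_tied * (B_tied @ s - A_tied @ x_small), lambd=beta_tied)\n",
"\n",
"def ista_step(s):\n",
"    # Hypothetical helper: s <- T_{rho/L}(s - (1/L) H^T (H s - x))\n",
"    return F.softshrink(s - (H_small.T @ (H_small @ s - x_small)) / L_small, lambd=rho_small / L_small)\n",
"\n",
"s_unfolded = s_ista = torch.zeros(m_small, dtype=torch.float64)\n",
"for _ in range(T_small):\n",
"    s_unfolded, s_ista = unfolded_step(s_unfolded), ista_step(s_ista)\n",
"print(torch.allclose(s_unfolded, s_ista))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},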
{
"cell_type": "code",
"source": [
"class LISTA_Model(nn.Module):\n",
"    def __init__(self, n, m, T=6, rho=1.0, H=None):\n",
"        super().__init__()\n",
"        self.n, self.m = n, m\n",
"        self.H = H\n",
"        self.T = T  # Number of unfolded ISTA iterations (layers)\n",
"        self.rho = rho  # Lagrangian multiplier (threshold scale)\n",
"        self.A = nn.Linear(n, m, bias=False)  # Weight matrix applied to x\n",
"        self.B = nn.Linear(m, m, bias=False)  # Weight matrix applied to s\n",
"        # Per-iteration thresholds beta and step sizes mu\n",
"        self.beta = nn.Parameter(torch.ones(T + 1, 1, 1), requires_grad=True)\n",
"        self.mu = nn.Parameter(torch.ones(T + 1, 1, 1), requires_grad=True)\n",
"        # Initialization from the model: A = H^T, B = H^T H.\n",
"        # clone() keeps training updates of A from modifying H in place.\n",
"        if H is not None:\n",
"            self.A.weight.data = H.t().clone()\n",
"            self.B.weight.data = H.t() @ H\n",
"\n",
"    def _shrink(self, s, beta):\n",
"        # Equals soft-thresholding of s with threshold rho * |beta|\n",
"        return beta * F.softshrink(s / beta, lambd=self.rho)\n",
"\n",
"    def forward(self, x):\n",
"        # First layer, starting from s = 0\n",
"        s = self._shrink(self.mu[0, :, :] * self.A(x), self.beta[0, :, :])\n",
"        for i in range(1, self.T + 1):\n",
"            # s <- T(s - mu (B s - A x))\n",
"            s = self._shrink(\n",
"                s - self.mu[i, :, :] * self.B(s) + self.mu[i, :, :] * self.A(x),\n",
"                self.beta[i, :, :],\n",
"            )\n",
"        return s\n"
],
"metadata": {
"id": "OWvCFWUC9GuH"
},
"execution_count": 9,
"outputs": []
},
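{
"cell_type": "markdown",
"source": [
"In `LISTA_Model._shrink`, the threshold enters as $\\beta\\,\\mathcal{T}_{\\rho}(u/\\beta)$. Since soft-thresholding is odd and positively homogeneous, this equals $\\mathcal{T}_{\\rho|\\beta|}(u)$ for any $\\beta\\neq 0$, so the effective threshold of layer $k$ is $\\rho|\\beta^{(k)}|$ (with the default $\\rho=1$, simply $|\\beta^{(k)}|$). The cell below checks this identity numerically for a few arbitrary values of $\\beta$; it is an illustration only and is not used by the training code."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"g = torch.Generator().manual_seed(0)  # local RNG, leaves the global seed untouched\n",
"u_demo = torch.randn(1000, generator=g, dtype=torch.float64)\n",
"rho_demo = 1.0  # the default rho of LISTA_Model\n",
"for beta_demo in [0.05, 0.5, 2.0, -0.3]:\n",
"    # Reparameterized form used in LISTA_Model._shrink\n",
"    lhs = beta_demo * F.softshrink(u_demo / beta_demo, lambd=rho_demo)\n",
"    # Direct soft-thresholding with threshold rho * |beta|\n",
"    rhs = F.softshrink(u_demo, lambd=rho_demo * abs(beta_demo))\n",
"    print(beta_demo, torch.allclose(lhs, rhs))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},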
{
"cell_type": "markdown",
"source": [
"The following functions train and evaluate LISTA, and run ISTA, for a given number of iterations, each returning the per-sample MSE on the test set."
],
"metadata": {
"id": "iUYvLlGsAMuL"
}
},
{
"cell_type": "code",
"source": [
"def lista_apply(train_loader, test_loader, T, H):\n",
"    n, m = H.shape  # H is n x m\n",
"    lista = LISTA_Model(n=n, m=m, T=T, H=H)\n",
"    # lista.cuda()\n",
"\n",
"    # Train, then report the final-epoch validation loss\n",
"    loss_test = train(lista, train_loader, test_loader)\n",
"    err_lista = loss_test[-1]\n",
"    return err_lista"
],
"metadata": {
"id": "XbErjona589S"
},
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def ista_apply(test_loader, T, H, rho=0.5):\n",
"    m = H.shape[1]\n",
"    # The maximal eigenvalue of H^T H\n",
"    L = float(eigvalsh(H.t() @ H, subset_by_index=[m - 1, m - 1])[0])\n",
"\n",
"    dataset = test_loader.dataset\n",
"    loss = 0\n",
"    for idx in range(len(dataset)):\n",
"        x, _, s = dataset[idx]\n",
"        s_hat = ista(x=x, H=H, rho=rho, L=L, max_itr=T)\n",
"        loss += F.mse_loss(s_hat, s, reduction=\"sum\").item()\n",
"\n",
"    return loss / len(dataset)"
],
"metadata": {
"id": "YFpYX3FyOXrC"
},
"execution_count": 11,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Comparing LISTA to ISTA\n",
"Finally, we show that by learning the parameterization of ISTA in a per-iteration manner, one can notably improve its convergence rate. To that end, the following loop compares ISTA to a trained LISTA for different numbers of iterations (layers), reporting the test MSE of each.\n"
],
"metadata": {
"id": "1xMFNzd5zUOJ"
}
},
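{
"cell_type": "markdown",
"source": [
"Before running the comparison, note how few parameters are learned per iteration in this LISTA variant: $\\boldsymbol{A}\\in\\mathbb{R}^{m\\times n}$ and $\\boldsymbol{B}\\in\\mathbb{R}^{m\\times m}$ are shared by all layers, so only $\\beta^{(k)}$ and $\\mu^{(k)}$ grow with the number of iterations, giving $mn + m^2 + 2(T+1)$ trainable parameters. The cell below counts them for $n=150$, $m=200$ using a hypothetical stand-in module, `LISTAParams`, with the same parameter shapes as `LISTA_Model`; it forks the RNG so the global random state used by the comparison is left unchanged."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class LISTAParams(nn.Module):\n",
"    # Hypothetical stand-in with the same parameter shapes as LISTA_Model\n",
"    def __init__(self, n, m, T):\n",
"        super().__init__()\n",
"        self.A = nn.Linear(n, m, bias=False)  # m x n, shared by all layers\n",
"        self.B = nn.Linear(m, m, bias=False)  # m x m, shared by all layers\n",
"        self.beta = nn.Parameter(torch.ones(T + 1, 1, 1))  # one threshold per layer\n",
"        self.mu = nn.Parameter(torch.ones(T + 1, 1, 1))  # one step size per layer\n",
"\n",
"counts = {}\n",
"with torch.random.fork_rng(devices=[]):  # restore the global CPU RNG state afterwards\n",
"    for T_demo in range(0, 13, 2):\n",
"        model_demo = LISTAParams(150, 200, T_demo)\n",
"        counts[T_demo] = sum(p.numel() for p in model_demo.parameters())\n",
"print(counts)  # m * n + m * m + 2 * (T + 1) for each T"
],
"metadata": {},
"execution_count": null,
"outputs": []
},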
{
"cell_type": "code",
"source": [
"# Number of unfoldings\n",
"tstart, tend, tstep = 0, 13, 2\n",
"T_opt = range(tstart, tend, tstep)\n",
"\n",
"ista_MSE = []\n",
"lista_MSE = []\n",
"for T in T_opt:\n",
"    # Apply ISTA with T iterations\n",
"    ista_MSE.append(ista_apply(test_loader, T, H))\n",
"    # Train and apply LISTA with T iterations / layers\n",
"    lista_MSE.append(lista_apply(train_loader, test_loader, T, H))\n",
"\n",
"# Plot the results\n",
"fig = plt.figure()\n",
"plt.plot(T_opt, ista_MSE, label='ISTA', color='b',linewidth=0.5)\n",
"plt.plot(T_opt, lista_MSE, label='LISTA', color='r', linewidth=2) \n",
"plt.xlabel('Number of iterations', fontsize=10)\n",
"plt.ylabel('MSE', fontsize=10)\n",
"plt.yscale(\"log\")\n",
"plt.legend()\n",
"plt.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 644
},
"id": "5fLLBxGBrkJP",
"outputId": "60085237-7ddb-4779-cd2f-0fbeec9ed760"
},
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 0, Train loss 0.17134949, Validation loss 0.11539615\n",
"Epoch 10, Train loss 0.02844153, Validation loss 0.03291038\n",
"Epoch 20, Train loss 0.02302526, Validation loss 0.03146555\n",
"Epoch 0, Train loss 0.16294222, Validation loss 0.01027154\n",
"Epoch 10, Train loss 0.00039944, Validation loss 0.00044741\n",
"Epoch 20, Train loss 0.00031676, Validation loss 0.00041448\n",
"Epoch 0, Train loss 0.16005076, Validation loss 0.00928998\n",
"Epoch 10, Train loss 0.00038584, Validation loss 0.00044998\n",
"Epoch 20, Train loss 0.00029110, Validation loss 0.00039879\n",
"Epoch 0, Train loss 0.16261384, Validation loss 0.00728203\n",
"Epoch 10, Train loss 0.00037819, Validation loss 0.00046729\n",
"Epoch 20, Train loss 0.00028347, Validation loss 0.00041146\n",
"Epoch 0, Train loss 0.15724220, Validation loss 0.00659751\n",
"Epoch 10, Train loss 0.00041298, Validation loss 0.00052746\n",
"Epoch 20, Train loss 0.00030126, Validation loss 0.00045046\n",
"Epoch 0, Train loss 0.15557589, Validation loss 0.00644273\n",
"Epoch 10, Train loss 0.00048024, Validation loss 0.00063444\n",
"Epoch 20, Train loss 0.00033723, Validation loss 0.00051873\n",
"Epoch 0, Train loss 0.15442260, Validation loss 0.00535904\n",
"Epoch 10, Train loss 0.00057877, Validation loss 0.00077456\n",
"Epoch 20, Train loss 0.00039724, Validation loss 0.00061686\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
}
]
}