Skip to content

Instantly share code, notes, and snippets.

@mtreviso
Last active February 2, 2021 11:00
Show Gist options
  • Save mtreviso/380b9c90a67e25175221f22e6ba84a18 to your computer and use it in GitHub Desktop.
Save mtreviso/380b9c90a67e25175221f22e6ba84a18 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "surprised-workstation",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import entmax\n",
"from cvxpylayers.torch import CvxpyLayer\n",
"import cvxpy as cp"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "vulnerable-person",
"metadata": {},
"outputs": [],
"source": [
"from functools import lru_cache\n",
"\n",
"def sparsemax_fw(X, dim=-1):\n",
"    \"\"\"Sparsemax along `dim`, built from plain differentiable torch ops.\n",
"\n",
"    Gradients come from autograd through max/sub/clamp. The threshold tau comes from\n",
"    entmax's private `_sparsemax_threshold_and_support` helper (leading underscore:\n",
"    not a public API, may change between entmax releases).\n",
"    \"\"\"\n",
"    max_val, _ = X.max(dim=dim, keepdim=True)\n",
"    X = X - max_val  # same numerical stability trick as softmax\n",
"    tau, _ = entmax.activations._sparsemax_threshold_and_support(X, dim=dim)\n",
"    return torch.clamp(X - tau, min=0)\n",
"\n",
"\n",
"@lru_cache(maxsize=None)\n",
"def _cvx_sparsemax_layer(n):\n",
"    \"\"\"Build, once per size n, the CvxpyLayer for the sparsemax projection QP.\n",
"\n",
"    argmin_y ||x - y||^2  s.t.  sum(y) = 1,  0 <= y <= 1\n",
"    \"\"\"\n",
"    x = cp.Parameter(n)\n",
"    y = cp.Variable(n)\n",
"    obj = cp.sum_squares(x - y)\n",
"    cons = [cp.sum(y) == 1, 0. <= y, y <= 1.]\n",
"    prob = cp.Problem(cp.Minimize(obj), cons)\n",
"    return CvxpyLayer(prob, [x], [y])\n",
"\n",
"\n",
"def cvx_sparsemax(X):\n",
"    \"\"\"Sparsemax over the last dimension via a differentiable convex-optimization layer.\n",
"\n",
"    The batch dimension is inferred by cvxpylayers and then broadcast. The layer is\n",
"    cached per hidden size, so building the cvxpy problem is not repeated (and not\n",
"    timed) on every call.\n",
"    \"\"\"\n",
"    max_val, _ = X.max(dim=-1, keepdim=True)\n",
"    X = X - max_val  # same numerical stability trick as softmax\n",
"    layer = _cvx_sparsemax_layer(X.shape[-1])\n",
"    out, = layer(X)\n",
"    # the solver is only approximately feasible: clip tiny negative values\n",
"    return torch.clamp(out, min=0)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "detected-cruise",
"metadata": {},
"outputs": [],
"source": [
"# tiny sizes so every printed tensor can be checked by eye\n",
"batch_size, n = 2, 8"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "pediatric-honor",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.7259, -0.7018, 2.2492, 0.4574, 0.1678, -1.1111, 0.2099, -0.6585],\n",
" [ 1.0499, 1.2296, -1.2662, 0.4773, 0.2722, -0.4636, -0.4626, -0.3327]],\n",
" requires_grad=True)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.manual_seed(0)  # make the random inputs (and every printed result) reproducible\n",
"X = torch.randn(batch_size, n)\n",
"# three independent leaf copies, one per implementation, so their gradients never mix\n",
"X1, X2, X3 = (X.clone().requires_grad_(True) for _ in range(3))\n",
"X1"
]
},
{
"cell_type": "markdown",
"id": "partial-continuity",
"metadata": {},
"source": [
"## Autograd sparsemax"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "heated-compatibility",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.4102, 0.5898, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
" grad_fn=<ClampBackward>)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"h = sparsemax_fw(X1, dim=-1)\n",
"# h is not a leaf tensor, so ask autograd to keep its gradient in h.grad\n",
"h.retain_grad()\n",
"h"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "engaged-delivery",
"metadata": {},
"outputs": [],
"source": [
"# toy objective: squared L2 norm of the sparsemax output\n",
"loss = h.pow(2).sum()\n",
"loss.backward()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "progressive-split",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [-0.1797, 0.1797, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])\n",
"tensor([[0.0000, 0.0000, 2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.8203, 1.1797, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])\n"
]
}
],
"source": [
"# gradient w.r.t. the input, then w.r.t. the sparsemax output\n",
"for grad in (X1.grad, h.grad):\n",
"    print(grad.data)"
]
},
{
"cell_type": "markdown",
"id": "blind-technical",
"metadata": {},
"source": [
"## Specialized sparsemax"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "beginning-default",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.4102, 0.5898, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
" grad_fn=<SparsemaxFunctionBackward>)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# entmax's sparsemax: a custom autograd Function (see grad_fn in the output)\n",
"h2 = entmax.sparsemax(X2, dim=-1)\n",
"h2.retain_grad()  # non-leaf: keep its gradient in h2.grad\n",
"h2"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "organic-environment",
"metadata": {},
"outputs": [],
"source": [
"loss2 = h2.pow(2).sum()  # same toy objective as above\n",
"loss2.backward()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "pleased-leone",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [-0.1797, 0.1797, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])\n",
"tensor([[0.0000, 0.0000, 2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.8203, 1.1797, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])\n"
]
}
],
"source": [
"# gradient w.r.t. the input, then w.r.t. the sparsemax output\n",
"for grad in (X2.grad, h2.grad):\n",
"    print(grad.data)"
]
},
{
"cell_type": "markdown",
"id": "academic-bloom",
"metadata": {},
"source": [
"## cvxpylayers sparsemax"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "southeast-medium",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 8.6379e-09, 0.0000e+00,\n",
" 4.7698e-09, 0.0000e+00],\n",
" [4.1017e-01, 5.8983e-01, 0.0000e+00, 7.6037e-08, 9.7740e-08, 0.0000e+00,\n",
" 0.0000e+00, 1.2660e-08]], grad_fn=<ClampBackward>)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# no dim argument: cvxpylayers always gets 1-D or 2-D inputs, so it always works on dim=-1\n",
"h3 = cvx_sparsemax(X3)\n",
"h3.retain_grad()  # non-leaf: keep its gradient in h3.grad\n",
"h3"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "paperback-commercial",
"metadata": {},
"outputs": [],
"source": [
"loss3 = h3.pow(2).sum()  # same toy objective as above\n",
"loss3.backward()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "extended-taxation",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[-1.0728e-09, 1.9481e-09, 7.0073e-09, -2.7421e-09, -5.3363e-09,\n",
" 1.5065e-10, -2.0931e-09, 2.1382e-09],\n",
" [-1.7966e-01, 1.7966e-01, 1.1963e-07, -1.5928e-08, 1.5564e-07,\n",
" -1.1809e-07, -1.1841e-07, -9.3386e-08]])\n",
"tensor([[0.0000, 0.0000, 2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.8203, 1.1797, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])\n"
]
}
],
"source": [
"# input gradient (solver noise ~1e-7 instead of exact zeros), then gradient w.r.t. h3\n",
"print(X3.grad.data)\n",
"print(h3.grad.data)"
]
},
{
"cell_type": "markdown",
"id": "veterinary-regard",
"metadata": {},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"id": "parental-mortality",
"metadata": {},
"source": [
"## Performance comparison"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "backed-adrian",
"metadata": {},
"outputs": [],
"source": [
"def small_net(X, prob_fn, dim=None):\n",
"    \"\"\"Apply prob_fn to X, backprop the squared L2 norm of the result, and return it.\n",
"\n",
"    `dim` is forwarded to prob_fn only when it is not None (cvx_sparsemax takes none).\n",
"    The returned tensor keeps its gradient in `.grad`; gradients also accumulate in X.grad.\n",
"    \"\"\"\n",
"    extra = {} if dim is None else {'dim': dim}\n",
"    h = prob_fn(X, **extra)\n",
"    h.retain_grad()\n",
"    (h ** 2).sum().backward()\n",
"    return h"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "other-criticism",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.4102, 0.5898, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
" grad_fn=<ClampBackward>)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"small_net(X1, prob_fn=sparsemax_fw, dim=-1)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "seventh-harassment",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.4102, 0.5898, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
" grad_fn=<SparsemaxFunctionBackward>)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"small_net(X2, prob_fn=entmax.sparsemax, dim=-1)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "defined-outreach",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 8.6379e-09, 0.0000e+00,\n",
" 4.7698e-09, 0.0000e+00],\n",
" [4.1017e-01, 5.8983e-01, 0.0000e+00, 7.6037e-08, 9.7740e-08, 0.0000e+00,\n",
" 0.0000e+00, 1.2660e-08]], grad_fn=<ClampBackward>)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"small_net(X3, prob_fn=cvx_sparsemax)  # cvx_sparsemax takes no dim argument"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "substantial-creator",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import numpy as np\n",
"\n",
"def timeit(module, *args, runs=5, verbose=False, **kwargs):\n",
"    \"\"\"Time `module(*args, **kwargs)`; return (mean elapsed seconds, last output).\n",
"\n",
"    One untimed warm-up call is made first, then `runs` timed calls. When CUDA is\n",
"    available, tensor args (and `module` itself if it is an nn.Module) are moved to the\n",
"    GPU and CUDA events are used for timing; otherwise time.perf_counter() is used.\n",
"\n",
"    Raises:\n",
"        ValueError: if runs < 1 (the mean time would be undefined).\n",
"    \"\"\"\n",
"    if runs < 1:\n",
"        raise ValueError('runs must be >= 1')\n",
"    t = []  # accum times (s)\n",
"    result = None  # output of the last timed call\n",
"    has_cuda = torch.cuda.is_available()\n",
"\n",
"    # prepare everything to cuda; plain functions / functools.partial objects have no .cuda()\n",
"    if has_cuda:\n",
"        if isinstance(module, torch.nn.Module):\n",
"            module = module.cuda()\n",
"        args = [arg.cuda() if torch.is_tensor(arg) else arg for arg in args]\n",
"\n",
"    # warm up the cpu/gpu\n",
"    module(*args, **kwargs)\n",
"\n",
"    for _ in range(runs):\n",
"        # start timer\n",
"        if has_cuda:\n",
"            event_start = torch.cuda.Event(enable_timing=True)\n",
"            event_end = torch.cuda.Event(enable_timing=True)\n",
"            event_start.record()\n",
"        else:\n",
"            time_start = time.perf_counter()\n",
"\n",
"        # computation\n",
"        result = module(*args, **kwargs)\n",
"\n",
"        # stop timer\n",
"        if has_cuda:\n",
"            event_end.record()\n",
"            # wait for the current device so both events have completed before reading them\n",
"            torch.cuda.synchronize()\n",
"            elapsed_time = event_start.elapsed_time(event_end) / 1000  # ms -> s\n",
"        else:\n",
"            elapsed_time = time.perf_counter() - time_start  # in s\n",
"\n",
"        t.append(elapsed_time)\n",
"\n",
"    # Averaged elapsed time\n",
"    if verbose:\n",
"        print('Avg. elapsed time: {}s (+/- {:.4f})'.format(np.mean(t), np.std(t)))\n",
"    return np.mean(t), result"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "increasing-antibody",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.025603865400000016,\n",
" tensor([[0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 8.6379e-09, 0.0000e+00,\n",
" 4.7698e-09, 0.0000e+00],\n",
" [4.1017e-01, 5.8983e-01, 0.0000e+00, 7.6037e-08, 9.7740e-08, 0.0000e+00,\n",
" 0.0000e+00, 1.2660e-08]], grad_fn=<ClampBackward>))"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from functools import partial\n",
"sparsemax_fw_net = partial(small_net, prob_fn=sparsemax_fw)\n",
"sparsemax_sp_net = partial(small_net, prob_fn=entmax.sparsemax)\n",
"sparsemax_cvx_net = partial(small_net, prob_fn=cvx_sparsemax)\n",
"\n",
"# keep all three measurements: only a cell's last expression is displayed\n",
"t_fw, _ = timeit(sparsemax_fw_net, X1, dim=-1, verbose=False)\n",
"t_sp, _ = timeit(sparsemax_sp_net, X2, dim=-1, verbose=False)\n",
"t_cvx, _ = timeit(sparsemax_cvx_net, X3, dim=None, verbose=False)\n",
"{'forward-sparsemax': t_fw, 'specialized-sparsemax': t_sp, 'cvx-sparsemax': t_cvx}"
]
},
{
"cell_type": "markdown",
"id": "opponent-geology",
"metadata": {},
"source": [
"## As hidden size grows"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "combined-investing",
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import pyplot as plt\n",
"\n",
"def plot_times(x, y1, y2, y3, xlabel='hidden size'):\n",
"    \"\"\"Plot three timing curves (given in seconds) against x, with the y axis in milliseconds.\"\"\"\n",
"    curves = [\n",
"        (y1, 'forward-sparsemax'),\n",
"        (y2, 'specialized-sparsemax'),\n",
"        (y3, 'cvx-sparsemax'),\n",
"    ]\n",
"    for seconds, label in curves:\n",
"        plt.plot(x, np.array(seconds) * 1000, '-', label=label)\n",
"    plt.legend()\n",
"    plt.xticks(x)\n",
"    plt.xlabel(xlabel)\n",
"    plt.ylabel(\"time (ms)\")\n",
"    plt.tight_layout()\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "powerful-space",
"metadata": {},
"outputs": [],
"source": [
"batch_size = 4\n",
"N = [1 << k for k in range(2, 20)]  # hidden sizes 4 .. 2**19"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "certain-latino",
"metadata": {},
"outputs": [],
"source": [
"# cvxpylayers solves a generic convex program and is expected to be far slower than the\n",
"# other two, so it is only run up to CVX_MAX_N; larger sizes get NaN, which matplotlib\n",
"# leaves out of the curve. Raise CVX_MAX_N if you can afford the runtime.\n",
"CVX_MAX_N = 2**10\n",
"# the solver returns ~1e-7 instead of exact zeros, so compare it with a looser atol\n",
"CVX_ATOL = 1e-4\n",
"\n",
"t1s, t2s, t3s = [], [], []\n",
"for n in N:\n",
"    X = torch.randn(batch_size, n)\n",
"\n",
"    # forward-pass sparsemax\n",
"    X1 = X.clone().requires_grad_(True)\n",
"    t1, res1 = timeit(sparsemax_fw_net, X1, dim=-1, verbose=False)\n",
"    t1s.append(t1)\n",
"\n",
"    # specialized sparsemax\n",
"    X2 = X.clone().requires_grad_(True)\n",
"    t2, res2 = timeit(sparsemax_sp_net, X2, dim=-1, verbose=False)\n",
"    t2s.append(t2)\n",
"\n",
"    # the two exact implementations should agree on outputs and grads\n",
"    # (each .grad accumulated the same number of backward calls inside timeit)\n",
"    assert torch.allclose(res1, res2)\n",
"    assert torch.allclose(X1.grad, X2.grad)\n",
"\n",
"    # cvx sparsemax (takes no dim argument)\n",
"    if n > CVX_MAX_N:\n",
"        t3s.append(float('nan'))\n",
"        continue\n",
"    X3 = X.clone().requires_grad_(True)\n",
"    t3, res3 = timeit(sparsemax_cvx_net, X3, dim=None, verbose=False)\n",
"    t3s.append(t3)\n",
"    assert torch.allclose(res2, res3, atol=CVX_ATOL)\n",
"    assert torch.allclose(X2.grad, X3.grad, atol=CVX_ATOL)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "accessible-christopher",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_times(N, t1s, t2s, t3s)  # xlabel defaults to 'hidden size'"
]
},
{
"cell_type": "markdown",
"id": "suburban-longer",
"metadata": {},
"source": [
"## As batch size grows"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "conditional-guess",
"metadata": {},
"outputs": [],
"source": [
"B = [1 << k for k in range(2, 11)]  # batch sizes 4 .. 1024\n",
"n = 1024"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "fundamental-broadway",
"metadata": {},
"outputs": [],
"source": [
"# cvxpylayers solves a generic convex program and is expected to be far slower than the\n",
"# other two, so it is only run up to CVX_MAX_BATCH rows; larger batches get NaN, which\n",
"# matplotlib leaves out of the curve. Raise CVX_MAX_BATCH if you can afford the runtime.\n",
"CVX_MAX_BATCH = 64\n",
"# the solver returns ~1e-7 instead of exact zeros, so compare it with a looser atol\n",
"CVX_ATOL = 1e-4\n",
"\n",
"t1s, t2s, t3s = [], [], []\n",
"for batch_size in B:\n",
"    X = torch.randn(batch_size, n)\n",
"\n",
"    # forward-pass sparsemax\n",
"    X1 = X.clone().requires_grad_(True)\n",
"    t1, res1 = timeit(sparsemax_fw_net, X1, dim=-1, verbose=False)\n",
"    t1s.append(t1)\n",
"\n",
"    # specialized sparsemax\n",
"    X2 = X.clone().requires_grad_(True)\n",
"    t2, res2 = timeit(sparsemax_sp_net, X2, dim=-1, verbose=False)\n",
"    t2s.append(t2)\n",
"\n",
"    # the two exact implementations should agree on outputs and grads\n",
"    # (each .grad accumulated the same number of backward calls inside timeit)\n",
"    assert torch.allclose(res1, res2)\n",
"    assert torch.allclose(X1.grad, X2.grad)\n",
"\n",
"    # cvx sparsemax (takes no dim argument)\n",
"    if batch_size > CVX_MAX_BATCH:\n",
"        t3s.append(float('nan'))\n",
"        continue\n",
"    X3 = X.clone().requires_grad_(True)\n",
"    t3, res3 = timeit(sparsemax_cvx_net, X3, dim=None, verbose=False)\n",
"    t3s.append(t3)\n",
"    assert torch.allclose(res2, res3, atol=CVX_ATOL)\n",
"    assert torch.allclose(X2.grad, X3.grad, atol=CVX_ATOL)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "grave-fifteen",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_times(x=B, y1=t1s, y2=t2s, y3=t3s, xlabel='batch size')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment