Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save motokimura/1a90c0b8c5628914b99a81cd91369636 to your computer and use it in GitHub Desktop.
Save motokimura/1a90c0b8c5628914b99a81cd91369636 to your computer and use it in GitHub Desktop.
Quantization error simulation of SiLU (Swish) activation
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def fake_quantize(x, n_bits=8):\n",
"    \"\"\"Simulate uniform (min/max-calibrated, asymmetric) quantization of `x`.\n",
"\n",
"    The range [min(x), max(x)] is split into 2**n_bits - 1 equal steps. Each\n",
"    element is rounded to the nearest step and mapped back to float, so the\n",
"    result can be compared directly with the unquantized input.\n",
"\n",
"    Args:\n",
"        x: Non-empty array-like of real values.\n",
"        n_bits: Bit width of the simulated integer grid (8 -> 256 levels).\n",
"\n",
"    Returns:\n",
"        Float array shaped like `x` holding the de-quantized values.\n",
"    \"\"\"\n",
"    x = np.asarray(x, dtype=float)\n",
"    min_ = x.min()\n",
"    max_ = x.max()\n",
"    q_max = 2 ** n_bits - 1\n",
"    scale = (max_ - min_) / q_max\n",
"    if scale == 0:\n",
"        # Constant input is exactly representable; returning early also avoids\n",
"        # the 0/0 (NaN) the affine mapping below would produce.\n",
"        return x.copy()\n",
"    # Round to the nearest level rather than truncating: truncation biases the\n",
"    # error to one sign and doubles its worst case from 0.5 to 1 LSB.\n",
"    x_int = np.clip(np.round((x - min_) / scale), 0, q_max)\n",
"    return x_int * scale + min_"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def _maybe_quantize(values, enabled):\n",
"    \"\"\"Return `fake_quantize(values)` if `enabled`, else `values` itself (no copy).\"\"\"\n",
"    return fake_quantize(values) if enabled else values\n",
"\n",
"\n",
"def relu(x, quantize=False):\n",
"    \"\"\"Rectified linear unit max(x, 0), optionally fake-quantized.\n",
"\n",
"    `x` must provide a `.clip(min=...)` method (e.g. a NumPy array).\n",
"    \"\"\"\n",
"    return _maybe_quantize(x.clip(min=0), quantize)\n",
"\n",
"\n",
"def sigmoid(x):\n",
"    \"\"\"Element-wise logistic function 1 / (1 + exp(-x)).\n",
"\n",
"    For very negative inputs np.exp overflows to inf (with a RuntimeWarning)\n",
"    and the result saturates to 0.\n",
"    \"\"\"\n",
"    denominator = 1 + np.exp(-x)\n",
"    return 1 / denominator\n",
"\n",
"\n",
"def silu(x, quantize=False):\n",
"    \"\"\"SiLU (Swish) activation x * sigmoid(x).\n",
"\n",
"    With `quantize=True` three tensors are fake-quantized: the input, the\n",
"    sigmoid output, and the final product.\n",
"    \"\"\"\n",
"    x_q = _maybe_quantize(x, quantize)\n",
"    gate = _maybe_quantize(sigmoid(x_q), quantize)\n",
"    return _maybe_quantize(x_q * gate, quantize)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def compute_sqnr(x, q):\n",
"    \"\"\"Signal-to-quantization-noise ratio of `q` against reference `x`, in dB.\n",
"\n",
"    Computes 20 * log10(||x|| / ||x - q||) using the L2 norm over all elements.\n",
"    If `q` equals a non-zero `x` exactly, the noise norm is 0 and the result is\n",
"    +inf (NumPy emits a divide-by-zero RuntimeWarning).\n",
"    \"\"\"\n",
"    signal_norm = np.linalg.norm(x)\n",
"    noise_norm = np.linalg.norm(x - q)\n",
"    return 20 * np.log10(signal_norm / noise_norm)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ReLU SQNR[dB]: 41.68666755974384\n",
"SiLU SQNR[dB]: 30.710681151668233\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Experiment settings; the fixed seed makes samples and SQNR values reproducible.\n",
"SEED = 0\n",
"N_SAMPLES = 1000\n",
"N_BIN_EDGES = 50\n",
"mu = 0\n",
"sigma = 1.0\n",
"\n",
"rng = np.random.default_rng(SEED)\n",
"x = rng.normal(mu, sigma, N_SAMPLES)\n",
"\n",
"# Evaluate each activation once in float and once with fake quantization.\n",
"relu_fp, relu_q = relu(x), relu(x, quantize=True)\n",
"silu_fp, silu_q = silu(x), silu(x, quantize=True)\n",
"relu_err = np.abs(relu_fp - relu_q)\n",
"silu_err = np.abs(silu_fp - silu_q)\n",
"\n",
"# Bin edges come from the data: values outside the bin range are dropped by\n",
"# the histogram, so a fixed upper edge could silently hide the largest errors.\n",
"bins = np.linspace(0, max(relu_err.max(), silu_err.max()), N_BIN_EDGES)\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.hist(relu_err, bins, label='ReLU', alpha=0.6)\n",
"ax.hist(silu_err, bins, label='SiLU', alpha=0.6)\n",
"ax.set_xlabel('quantization error: abs(float - fake_quantized)')\n",
"ax.set_ylabel('count')\n",
"ax.set_title(f'Fake-quantization error, x ~ N({mu}, {sigma}^2), n={N_SAMPLES}')\n",
"ax.legend()\n",
"\n",
"print('ReLU SQNR[dB]:', compute_sqnr(relu_fp, relu_q))\n",
"print('SiLU SQNR[dB]:', compute_sqnr(silu_fp, silu_q))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment