Skip to content

Instantly share code, notes, and snippets.

@nickodell
Last active July 31, 2023 02:16
Show Gist options
  • Save nickodell/490d22cb3077fde89dceb55cd32db38a to your computer and use it in GitHub Desktop.
Save nickodell/490d22cb3077fde89dceb55cd32db38a to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "1ab0e3df",
"metadata": {},
"source": [
"# Quantizing normally distributed floats in Python and NumPy\n",
"https://stackoverflow.com/questions/76798643/quantizing-normally-distributed-floats-in-python-and-numpy"
]
},
{
"cell_type": "markdown",
"id": "eaf966fd",
"metadata": {},
"source": [
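"Let the values in the array `A` be sampled from a Gaussian\n",
"distribution. I want to replace every value in `A` with one of `n_R`\n",
"\"representatives\" in `R` so that the total quantization error is\n",
"minimized. Below, the error is measured as the mean absolute distance\n",
"between each value and its nearest representative.\n",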
"\n",
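"Here is NumPy code that does linear quantization, followed by Gaussian-quantile,\n",
"Jenks natural-breaks and optimized-sigmoid alternatives for comparison:\n",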
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b2c36039",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from scipy.stats import norm\n",
"import math\n",
"import sklearn.preprocessing\n",
"import jenkspy\n",
"import scipy.optimize"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "07717fe5",
"metadata": {},
"outputs": [],
"source": [
"# Fixed seed so that Restart & Run All reproduces A and the Jenks subsample drawn later\n",
"np.random.seed(0)\n",
"\n",
"n_A, n_R = 1_000_000, 256  # number of samples, number of representatives\n",
"mu, sig = 500, 250  # mean and standard deviation of the Gaussian\n",
"A = np.random.normal(mu, sig, size=n_A)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5284bf49",
"metadata": {},
"outputs": [],
"source": [
"# Linear method: n_R evenly spaced representatives spanning the observed range of A\n",
"lo = A.min()\n",
"hi = A.max()\n",
"linear_R = np.linspace(lo, hi, num=n_R)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a35a82b4",
"metadata": {},
"outputs": [],
"source": [
"def measure_loss(R, values=None):\n",
"    \"\"\"Mean absolute quantization error when each value is snapped to its nearest representative.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    R : array-like, sorted ascending\n",
"        The representatives. Must be ascending because np.searchsorted is used.\n",
"    values : array-like, optional\n",
"        The values to quantize. Defaults to the notebook-level sample ``A``.\n",
"\n",
"    Returns\n",
"    -------\n",
"    float\n",
"        Mean of ``|value - nearest representative|`` over ``values``.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If ``R`` is empty.\n",
"    \"\"\"\n",
"    if values is None:\n",
"        values = A\n",
"    R = np.asarray(R)\n",
"    values = np.asarray(values)\n",
"    if len(R) == 0:\n",
"        raise ValueError(\"R must contain at least one representative\")\n",
"    # Clip with the actual length of R (not the global n_R) so any R works.\n",
"    last = len(R) - 1\n",
"    # Index of the first representative >= each value (upper neighbour) ...\n",
"    upper = np.clip(np.searchsorted(R, values, 'left'), 0, last)\n",
"    # ... and of the last representative <= each value (lower neighbour).\n",
"    lower = np.clip(np.searchsorted(R, values, 'right') - 1, 0, last)\n",
"    # Outside [R[0], R[-1]] both indices clip to the same end point, so the smaller\n",
"    # of the two absolute distances is always the distance to the nearest representative.\n",
"    err = np.minimum(np.abs(R[upper] - values), np.abs(values - R[lower]))\n",
"    return np.mean(err)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "00efe98a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Linear method: 2.446699788910965\n"
]
}
],
"source": [
"# Mean absolute quantization error of the linear representatives\n",
"print('Linear method:', measure_loss(R=linear_R))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "02da61f1",
"metadata": {},
"outputs": [],
"source": [
"# Gaussian method: representatives at evenly spaced quantiles of N(mu, sig), limited to\n",
"# mu +/- 3*sig so that ppf never receives 0 or 1 (which would give -inf / +inf).\n",
"# `dist` is also reused by the sigmoid helpers further down.\n",
"dist = norm(loc=mu, scale=sig)\n",
"bounds = dist.cdf([mu - 3 * sig, mu + 3 * sig])\n",
"pp = np.linspace(*bounds, n_R)\n",
"gaussian_R = dist.ppf(pp)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "cfe549ed",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Gaussian method: 1.6443746592274302\n"
]
}
],
"source": [
"# Mean absolute quantization error of the Gaussian-quantile representatives\n",
"print('Gaussian method:', measure_loss(R=gaussian_R))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "94336abe",
"metadata": {},
"outputs": [],
"source": [
"# Jenks natural breaks, computed on a 10,000-element subsample of A (drawn with replacement)\n",
"# to keep the runtime manageable. n_classes = n_R - 1 is used so that the returned break\n",
"# list presumably has n_R entries (class boundaries incl. min/max -- confirm in jenkspy docs);\n",
"# the breaks themselves are then used as the representatives.\n",
"A_samp = np.random.choice(A, size=10_000, replace=True)\n",
"breaks = jenkspy.jenks_breaks(A_samp, n_classes=n_R - 1)\n",
"jenks_R = np.array(breaks)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "62c120c0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Jenks method: 1.2877975273144864\n"
]
}
],
"source": [
"# Mean absolute quantization error of the Jenks-break representatives\n",
"print('Jenks method:', measure_loss(R=jenks_R))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "213f6453",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Plot quantiles of R within distribution against index of R\n",
"\n",
"def quantile_plot(R, name):\n",
"    \"\"\"Plot the N(mu, sig) CDF value of each representative in R against its index.\n",
"\n",
"    A straight line means the representatives are evenly spaced in probability\n",
"    (which is how the Gaussian method constructs them). Draws on the current\n",
"    matplotlib axes and refreshes the legend, so successive calls overlay.\n",
"    \"\"\"\n",
"    x = np.arange(len(R))\n",
"    y = norm(loc=mu, scale=sig).cdf(R)\n",
"    plt.plot(x, y, label=name)\n",
"    plt.xlabel('Class index')\n",
"    plt.ylabel('Distribution quantile')\n",
"    plt.legend()\n",
"\n",
"\n",
"# Compare the three methods computed in the cells above on one figure\n",
"quantile_plot(linear_R, 'linear')\n",
"quantile_plot(gaussian_R, 'gaussian')\n",
"quantile_plot(jenks_R, 'jenks')\n",
"plt.title('Quantile of each representative, by method');"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "4bc1984b",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Estimated curve strength of linear: 16.7025\n",
"Average error 0.235003\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Estimated curve strength of gaussian: 0.0005\n",
"Average error 0.000154\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Estimated curve strength of jenks: 6.7242\n",
"Average error 0.020398\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Define sigmoid function, determine curve strength of methods so far\n",
"\n",
"def sigmoid(x, strength):\n",
"    \"\"\"Logistic curve of the given strength evaluated over the 1-D array x, rescaled into (0, 1).\n",
"\n",
"    x is min-max scaled to [-0.5, 0.5]; the logistic output is then min-max scaled to\n",
"    [epsilon, 1 - epsilon] so that dist.ppf of the result is always finite.\n",
"    strength == 0 is special-cased to a straight line.\n",
"    NOTE(review): a negative strength gives a decreasing curve, so sigmoid_classes\n",
"    would return descending representatives, which measure_loss does not support.\n",
"    \"\"\"\n",
"    scaler_input = sklearn.preprocessing.MinMaxScaler(feature_range=(-0.5, 0.5))\n",
"    # Avoid having points at inf\n",
"    epsilon = 1e-5\n",
"    x = scaler_input.fit_transform(x.reshape(-1, 1)).flatten()\n",
"    if strength != 0:\n",
"        output = 1 / (1 + np.exp(-(x * strength)))\n",
"    else:\n",
"        # 0-strength line. Make sure final output has some variance\n",
"        output = np.linspace(0, 1, len(x))\n",
"    scaler_output = sklearn.preprocessing.MinMaxScaler(feature_range=(0 + epsilon, 1 - epsilon))\n",
"    return scaler_output.fit_transform(output.reshape(-1, 1)).reshape(-1)\n",
"\n",
"def sse(y_true, y_pred):\n",
"    \"\"\"Sum of squared errors between two equally shaped arrays.\"\"\"\n",
"    return np.sum((y_true - y_pred) ** 2)\n",
"\n",
"def sigmoid_classes(strength, n_classes):\n",
"    \"\"\"n_classes representatives whose quantiles under the global `dist` follow sigmoid(., strength).\"\"\"\n",
"    x = np.arange(n_classes)\n",
"    return dist.ppf(sigmoid(x, strength))\n",
"\n",
"def est_curve_strength(R, name):\n",
"    \"\"\"Fit the sigmoid strength that best reproduces the quantile curve of R; print and plot the fit.\n",
"\n",
"    The fit minimizes sse(dist.cdf(R), sigmoid(index, strength)); the printed error is\n",
"    that minimized sum of squared errors over all len(R) classes.\n",
"    \"\"\"\n",
"    y = dist.cdf(R)\n",
"    x = np.arange(len(R))\n",
"    res = scipy.optimize.minimize(lambda x0: sse(y, sigmoid(x, x0[0])), x0=[10], options={'gtol': 1e-10})\n",
"    print(f\"Estimated curve strength of {name}: {res.x[0]:.4f}\")\n",
"    # res.fun is the value of sse, i.e. a *sum* over classes, not an average\n",
"    print(f\"Sum of squared errors {res.fun:.6f}\")\n",
"    plt.title(name)\n",
"    plt.plot(x, sigmoid(x, *res.x), label='fit')\n",
"    plt.plot(x, y, label='true')\n",
"    plt.xlabel('Class index')\n",
"    plt.ylabel('Distribution quantile')\n",
"    plt.legend()\n",
"    plt.show()\n",
"\n",
"est_curve_strength(linear_R, 'linear')\n",
"est_curve_strength(gaussian_R, 'gaussian')\n",
"est_curve_strength(jenks_R, 'jenks')"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "1809fa9d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" fun: 1.3198519807543092\n",
" hess_inv: array([[0.04675019]])\n",
" jac: array([6.63101673e-06])\n",
" message: 'Optimization terminated successfully.'\n",
" nfev: 44\n",
" nit: 7\n",
" njev: 22\n",
" status: 0\n",
" success: True\n",
" x: array([6.48649694])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Search for the sigmoid strength that minimizes the quantization loss,\n",
"# starting near the strength fitted to the linear method (~16.7 above).\n",
"initial_strength = 16\n",
"res = scipy.optimize.minimize(\n",
"    lambda params: measure_loss(sigmoid_classes(params[0], n_R)),\n",
"    x0=[initial_strength],\n",
")\n",
"sigmoid_R = sigmoid_classes(res.x[0], n_R)\n",
"res"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "48be5c0c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fdc8859fb80>]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Representative values chosen by the optimized sigmoid, in class-index order\n",
"plt.plot(sigmoid_R)\n",
"plt.xlabel('Class index')\n",
"plt.ylabel('Representative value')\n",
"plt.title('Optimized sigmoid representatives');"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "111ba97d",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Compare the optimized sigmoid representatives against the Jenks breaks in quantile space\n",
"for candidate_R, label in [(sigmoid_R, 'sigmoid'), (jenks_R, 'jenks')]:\n",
"    quantile_plot(candidate_R, label)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9fc381f5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment