Skip to content

Instantly share code, notes, and snippets.

@jph00
Created January 21, 2018 17:53
Show Gist options
  • Save jph00/30cfed589a8008325eae8f36e2c5b087 to your computer and use it in GitHub Desktop.
Save jph00/30cfed589a8008325eae8f36e2c5b087 to your computer and use it in GitHub Desktop.
Fast weighted sampling using the alias method in Numba
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import numpy.random as npr, numpy as np\nfrom numba import jit",
"execution_count": 1,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "@jit(nopython=True)\ndef sample(n, q, J, r1, r2):\n    \"\"\"Draw n indices from a prebuilt alias table.\n\n    n is the number of draws, q the per-column acceptance thresholds and\n    J the per-column alias indices. r1 and r2 each hold at least n\n    uniforms in [0, 1): r1 picks the column, r2 is compared against q.\n    Returns an int32 array of n sampled indices.\n    \"\"\"\n    res = np.zeros(n, dtype=np.int32)\n    lj = len(J)\n    for i in range(n):\n        # Scale the uniform to a column index in [0, lj).\n        kk = int(np.floor(r1[i]*lj))\n        # Keep the column itself with probability q[kk], else take its alias.\n        if r2[i] < q[kk]: res[i] = kk\n        else: res[i] = J[kk]\n    return res\n\nclass AliasSample():\n    \"\"\"Weighted sampler over range(len(probs)) built on the alias method.\n\n    K is the number of outcomes, q the acceptance threshold of each column\n    and J the alias index of each column.\n    \"\"\"\n\n    def __init__(self, probs):\n        \"\"\"Build the alias table from probs.\n\n        probs is a non-empty 1-D sequence of probabilities; the table scales\n        it by K and compares against 1.0, so it is expected to sum to 1.\n        Raises ValueError when probs is empty.\n        \"\"\"\n        probs = np.asarray(probs, dtype=np.float64)\n        self.K=K= len(probs)\n        if K == 0: raise ValueError('probs must contain at least one entry')\n        self.q=q= K*probs\n        # Every column starts as its own alias. np.int64 replaces the np.int\n        # alias, which NumPy 1.24 removed.\n        self.J=J= np.arange(K, dtype=np.int64)\n\n        below = q < 1.0\n        smaller,larger = np.flatnonzero(below).tolist(),np.flatnonzero(~below).tolist()\n\n        while len(smaller) > 0 and len(larger) > 0:\n            small,large = smaller.pop(),larger.pop()\n            J[small] = large\n            # The large column donates the mass needed to fill the small one.\n            q[large] = q[large] - (1.0 - q[small])\n            if q[large] < 1.0: smaller.append(large)\n            else: larger.append(large)\n\n        # When probs sums to 1, unpaired columns differ from 1.0 only by\n        # round-off; pin them so they never fall through to an alias.\n        for kk in smaller + larger: q[kk] = 1.0\n\n    def draw_one(self):\n        \"\"\"Return a single sampled index as a Python int.\"\"\"\n        q,J = self.q,self.J\n        kk = int(np.floor(npr.rand()*len(J)))\n        if npr.rand() < q[kk]: return kk\n        else: return int(J[kk])\n\n    def draw_n(self, n):\n        \"\"\"Return an int32 array of n sampled indices.\"\"\"\n        r1,r2 = npr.rand(n),npr.rand(n)\n        return sample(n,self.q,self.J,r1,r2)",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# some weights to do weighted sampling by\nprs = npr.random(30000)\nprs/=prs.sum()",
"execution_count": 8,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "a = AliasSample(prs)",
"execution_count": 9,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%timeit t = a.draw_n(5000)",
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"text": "172 µs ± 26.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%timeit t = np.random.choice(len(prs), 5000, p=prs)",
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"text": "988 µs ± 87.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "cum_prs = prs.cumsum()",
"execution_count": 13,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%timeit t = np.searchsorted(cum_prs, np.random.random(5000))",
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"text": "640 µs ± 7.67 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.6.4",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"toc": {
"threshold": 4,
"number_sections": true,
"toc_cell": false,
"toc_window_display": false,
"toc_section_display": "block",
"sideBar": true,
"navigate_menu": true,
"moveMenuLeft": true,
"widenNotebook": false,
"colors": {
"hover_highlight": "#DAA520",
"selected_highlight": "#FFD700",
"running_highlight": "#FF0000",
"wrapper_background": "#FFFFFF",
"sidebar_border": "#EEEEEE",
"navigate_text": "#333333",
"navigate_num": "#000000"
},
"nav_menu": {
"width": "252px",
"height": "12px"
}
},
"gist": {
"id": "",
"data": {
"description": "Fast alias sampling with Numba",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment