Skip to content

Instantly share code, notes, and snippets.

@mzmttks
Created October 22, 2015 16:56
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save mzmttks/ead0951367e39c4ffba6 to your computer and use it in GitHub Desktop.
Save mzmttks/ead0951367e39c4ffba6 to your computer and use it in GitHub Desktop.
A sample implementation of Bernoulli Bandit Solver using Thompson Sampling
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import scipy\n",
"import scipy.stats\n",
"import numpy.random\n",
"import collections\n",
"import pprint"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Thompson Sampling\n",
"## Bernoulli Bandits\n",
"\n",
"### 式\n",
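"* 各 Bandit $a$ は成功 (1) か失敗 (0) のどちらか一方を出力し、その期待値 (成功確率) は $\\mu_a$ である。\n",
"\n",
"* 事前分布にはベータ分布を使う。ベータ分布はベルヌーイ分布の共役事前分布なので、事後分布もベータ分布になる。\n",
"* 時刻 $t$ における Bandit $a$ の事後分布: $\\mathrm{Beta}(\\theta \\mid \\alpha, \\beta)$ ($\\theta$ は成功確率)\n",
"    * $\\alpha = S_a(t) + 1$\n",
"    * $\\beta = F_a(t) + 1$\n",
"\n",
"* $S_a(t)$ (`Sat` in the code): number of successes of bandit $a$ up to time $t$\n",
"* $F_a(t)$ (`Fat` in the code): number of failures of bandit $a$ up to time $t$\n",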
"\n",
"### References\n",
" * http://www.economics.uci.edu/~ivan/asmb.874.pdf\n",
" * S. Agrawal and N. Goyal: \"Analysis of Thompson Sampling for the Multi-armed Bandit Problem\", Proc. COLT 2012, JMLR Workshop and Conference Proceedings, vol. 23, pp. 39.1-39.26, 2012"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Experiment parameters.\n",
"\n",
"numBandits = 3    # number of arms\n",
"numTrials = 5000  # number of pulls to simulate\n",
"\n",
"# True success probability of each arm; hidden from the solver and used only\n",
"# by the simulator.\n",
"rewards = [0.9, 0.9, 0.4] # expected rewards (hidden)\n",
"assert len(rewards) == numBandits, \"need one reward probability per arm\"\n",
"\n",
"# Seed NumPy's global RNG so Restart & Run All reproduces the same run:\n",
"# numpy.random.beta and scipy.stats.bernoulli.rvs (no random_state given)\n",
"# both draw from it.\n",
"randomSeed = 0\n",
"numpy.random.seed(randomSeed)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Bandit (slot machine) simulator.\n",
"\n",
"def getResult(prob):\n",
"    \"\"\"Simulate one pull of an arm whose success probability is `prob`.\n",
"\n",
"    Returns a single Bernoulli draw: 1 with probability `prob`, otherwise 0.\n",
"    \"\"\"\n",
"    draws = scipy.stats.bernoulli.rvs(prob, size=1)\n",
"    return draws[0]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Multi-armed bandit solver using Thompson sampling.\n",
"\n",
"class BernoulliBandit(object):\n",
"    \"\"\"Thompson-sampling agent for a Bernoulli multi-armed bandit.\n",
"\n",
"    Each arm's success probability gets a Beta(1, 1) (uniform) prior; after\n",
"    observing S successes and F failures its posterior is Beta(S + 1, F + 1).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, numBandits):\n",
"        \"\"\"Create an agent for `numBandits` arms, indexed 0..numBandits-1.\"\"\"\n",
"        self.numBandits = numBandits\n",
"        # results[b] lists the observed outcomes (1 = success, 0 = failure) of arm b.\n",
"        self.results = {i: [] for i in range(self.numBandits)}\n",
"        # Kept for backward compatibility; nothing in this class populates it.\n",
"        self.posteriorHistory = []\n",
"\n",
"    def _counts(self, b):\n",
"        \"\"\"Return (successes, failures) observed so far for arm `b`.\"\"\"\n",
"        outcomes = self.results[b]\n",
"        return outcomes.count(1), outcomes.count(0)\n",
"\n",
"    def getBandit(self):\n",
"        \"\"\"Choose the next arm to pull.\n",
"\n",
"        Draws one sample from every arm's Beta posterior and returns the index\n",
"        of the largest sample (numpy argmax: ties go to the lowest index).\n",
"        \"\"\"\n",
"        samples = []\n",
"        for b in range(self.numBandits):\n",
"            Sat, Fat = self._counts(b)\n",
"            samples.append(numpy.random.beta(Sat + 1, Fat + 1))\n",
"        return numpy.array(samples).argmax()\n",
"\n",
"    def feed(self, selectedBandit, result):\n",
"        \"\"\"Record the outcome (1 = success, 0 = failure) of pulling `selectedBandit`.\"\"\"\n",
"        self.results[selectedBandit].append(result)\n",
"\n",
"    def __str__(self):\n",
"        \"\"\"One line per arm with its failure and success counts.\"\"\"\n",
"        out = \"\"\n",
"        for b in range(self.numBandits):\n",
"            Sat, Fat = self._counts(b)\n",
"            # Arguments must follow the label order: failures first, then successes.\n",
"            out += \"Bandit[%d] Fails: %4d\\t Successes: %6d\\n\" % (b, Fat, Sat)\n",
"        return out"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 100 200 300 400 500 600 700 800 900 1000 1100 1200 1300 1400 1500 1600 1700 1800 1900 2000 2100 2200 2300 2400 2500 2600 2700 2800 2900 3000 3100 3200 3300 3400 3500 3600 3700 3800 3900 4000 4100 4200 4300 4400 4500 4600 4700 4800 4900\n",
"Bandit[0] Fails: 2134\t Successes: 245\n",
"Bandit[1] Fails: 2354\t Successes: 264\n",
"Bandit[2] Fails: 0\t Successes: 3\n",
"\n",
"[0.9, 0.9, 0.4]\n"
]
}
],
"source": [
"# Run the simulation: each trial the agent picks an arm, the simulator pulls it,\n",
"# and the outcome is fed back to the agent.\n",
"bandit = BernoulliBandit(numBandits)\n",
"for trial in range(numTrials):\n",
"    if not trial % 100:\n",
"        print trial,  # progress marker every 100 pulls\n",
"    chosen = bandit.getBandit()\n",
"    bandit.feed(chosen, getResult(rewards[chosen]))\n",
"print\n",
"print bandit\n",
"print rewards"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment