Skip to content

Instantly share code, notes, and snippets.

@maxpagels
Created February 12, 2017 18:12
Show Gist options
  • Save maxpagels/f974d85b6a7dbb3b97193b384086e070 to your computer and use it in GitHub Desktop.
Save maxpagels/f974d85b6a7dbb3b97193b384086e070 to your computer and use it in GitHub Desktop.
import numpy as np
from flask import Flask
from flask import request
from flask import jsonify
# A simple implementation of a multi-armed bandit using Thompson Sampling.
class ThompsonBandit(object):
    """Bernoulli multi-armed bandit that selects arms by Thompson Sampling.

    Every arm keeps a count of pulls (``trials``) and of credited rewards
    (``wins``). When choosing an arm, each arm's success rate is sampled from
    ``Beta(1 + wins, 1 + trials - wins)`` and the arm with the largest sample
    is played.
    """

    def __init__(self, n_arms):
        """Create a bandit with ``n_arms`` arms and no recorded history."""
        self.wins = np.zeros(n_arms)    # rewards credited so far, per arm
        self.trials = np.zeros(n_arms)  # pulls so far, per arm
        self.N = 0                      # total pulls across all arms

    def pull(self):
        """Choose an arm, record one trial for it, and return its index.

        Returns:
            The index of the chosen arm as a NumPy integer scalar.
        """
        samples = np.random.beta(1 + self.wins, 1 + self.trials - self.wins)
        arm = np.argmax(samples)
        self.trials[arm] += 1
        self.N += 1
        return arm

    def reward(self, arm):
        """Credit one win to ``arm`` if it has a pull that is not yet rewarded.

        A reward is ignored when the arm already has as many wins as trials,
        so an arm can never collect more wins than pulls.

        Args:
            arm: Integer index of the arm to reward, in ``[0, n_arms)``.

        Raises:
            IndexError: If ``arm`` is not an integer or is outside
                ``[0, n_arms)``. Negative values are rejected rather than
                wrapping around to arms counted from the end.
        """
        # bool is a subclass of int; treat True/False as invalid indices.
        if isinstance(arm, bool) or not isinstance(arm, (int, np.integer)):
            raise IndexError("arm must be an integer index, got %r" % (arm,))
        n_arms = len(self.wins)
        if not 0 <= arm < n_arms:
            raise IndexError(
                "arm %d is out of range for %d arms" % (arm, n_arms))
        # trials[arm] never exceeds N while counts are only changed through
        # pull(), so wins < trials already implies wins <= N.
        if self.trials[arm] > self.wins[arm]:
            self.wins[arm] += 1
# WSGI application object that the route decorators below register against.
app = Flask(__name__)

# One module-level bandit with ten arms, used by every request handler.
bandit = ThompsonBandit(n_arms=10)
@app.route("/get_arm", methods=['GET'])
def test():
    """Pull the module-level bandit and report the chosen arm as JSON."""
    chosen = bandit.pull()
    # .item() turns the NumPy scalar into a plain Python value for jsonify.
    payload = {'arm': chosen.item()}
    return jsonify(payload)
@app.route('/reward', methods=['POST'])
def reward():
    """Credit a reward to the arm named in the JSON request body.

    Expects a JSON object such as ``{"arm": 3}``. On success the bandit is
    updated and a confirmation message is returned. If the body is not a
    JSON object, has no ``'arm'`` key, or the value is not an integer index
    of an existing arm, nothing is updated and HTTP 400 is returned.
    """
    # silent=True yields None for a missing or malformed JSON body instead of
    # raising inside the handler.
    payload = request.get_json(silent=True)
    arm = payload.get('arm') if isinstance(payload, dict) else None

    # bool is a subclass of int, so exclude it explicitly. Negative values
    # are rejected because array indexing would wrap them to another arm.
    is_index = isinstance(arm, int) and not isinstance(arm, bool)
    if not is_index or not 0 <= arm < len(bandit.wins):
        return jsonify({'error': "'arm' must be an integer arm index"}), 400

    bandit.reward(arm)
    return jsonify({'message': 'reward sent successfully'})
if __name__ == "__main__":
    # Start Flask's built-in server on all interfaces, port 8000, only when
    # this file is executed directly.
    app.run(port=8000, host="0.0.0.0")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment