Skip to content

Instantly share code, notes, and snippets.

@john-bradshaw
Created November 15, 2017 18:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save john-bradshaw/089a352dd85d2a75b5dc090f8d0d6fa2 to your computer and use it in GitHub Desktop.
Save john-bradshaw/089a352dd85d2a75b5dc090f8d0d6fa2 to your computer and use it in GitHub Desktop.
Simple example of running Bayesian optimisation with Gaussian processes (GPs). Works with Python 3, TensorFlow 1.4, python-fire and GPflow at git commit 4ff00cbbc83efff8cb537f16a7eb1c1e11de3a75
import enum
import numpy as np
from scipy import stats
import fire
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
from matplotlib.widgets import Button
import gpflow as gpf
# Shared configuration for both demos: the 1-D search interval, the RBF kernel
# lengthscale, and the standard deviation of the Gaussian observation noise.
run_settings = dict(
    bounds=[0, 5],
    gp_lengthscale=1.,
    measurement_noise_std=0.05,
)
class OracleGPModel(object):
    """Hidden "ground-truth" function, revealed one point at a time from a GP.

    Each queried location receives a latent value sampled from the GP conditioned on
    all latent values sampled so far, so the revealed values hang together as one
    function. Observations returned to the caller add Gaussian noise with standard
    deviation ``run_settings['measurement_noise_std']``. Both ends of
    ``run_settings["bounds"]`` start out pinned to 0.
    """

    def __init__(self, rng: np.random.RandomState):
        lower, upper = run_settings["bounds"]
        self.seen_points = np.array([[lower], [upper]], dtype=float)
        self.seen_fs = np.zeros((2, 1), dtype=float)  # latent (noise-free) values
        self.seen_ys = np.zeros((2, 1), dtype=float)  # noisy observations
        self._rng = rng
        kernel = gpf.kernels.RBF(1, lengthscales=run_settings['gp_lengthscale'])
        self.model = gpf.models.GPR(self.seen_points, self.seen_ys, kern=kernel)
        # Small fixed likelihood variance; presumably acts as jitter so the oracle
        # (nearly) interpolates its own latent samples.
        self.model.likelihood.variance = 1e-3
        self.model.likelihood.set_trainable(False)

    def query(self, point):
        """Reveal the oracle at ``point`` and return a noisy observation there.

        :param point: location to evaluate; promoted to a 2-D array with ``np.atleast_2d``.
        :return: 2-D array holding the noisy observation(s).
        """
        print(point)
        x_new = np.atleast_2d(point)
        self.seen_points = np.concatenate([self.seen_points, x_new], axis=0)
        # Draw one sample of the latent function at x_new from the model as it stands,
        # i.e. conditioned only on the previously revealed latent values.
        f_new = self.model.predict_f_samples(x_new, 1)[0]
        self.seen_fs = np.concatenate([self.seen_fs, f_new], axis=0)
        noise = self._rng.randn(*x_new.shape) * run_settings['measurement_noise_std']
        y_new = f_new + noise
        self.seen_ys = np.concatenate([self.seen_ys, y_new], axis=0)
        # Condition future draws on the noise-free latent values, not on the noisy ys.
        self.model.X.assign(self.seen_points)
        self.model.Y.assign(self.seen_fs)
        return y_new

    def plot(self, ax, plot_gp_parts=True):
        """Draw the noisy observations and, optionally, the GP mean and a ±2 std band."""
        xs = np.linspace(*run_settings["bounds"], 100)[:, np.newaxis]
        if plot_gp_parts:
            mean, var = self.model.predict_f(xs)
            std = np.sqrt(var[:, 0])
            ax.plot(xs, mean, '--', color='#2e267c', lw=3)
            ax.fill_between(xs[:, 0], mean[:, 0] - 2 * std, mean[:, 0] + 2 * std,
                            color='#37c2df', alpha=0.2)
        ax.plot(self.seen_points, self.seen_ys, 'kx', ms=10, mew=5)
def run_querier_human_in_loop():
    """Interactive demo: the user plays the optimiser by clicking where to query.

    Each click inside the axes queries the oracle at the clicked x position, then the
    plot is cleared and redrawn with the oracle's updated GP posterior and observations.
    """
    rng = np.random.RandomState(100)
    gp_model = OracleGPModel(rng)
    fig, ax = plt.subplots()
    gp_model.plot(ax)

    def onclick(event):
        # Clicks outside the axes have no data coordinates (xdata/ydata are None);
        # formatting them with %f or querying at None would raise in the callback.
        if event.inaxes is not ax or event.xdata is None:
            return
        print('%s click: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' %
              ('double' if event.dblclick else 'single', event.button,
               event.x, event.y, event.xdata, event.ydata))
        gp_model.query(np.array([[event.xdata]]))
        ax.clear()
        gp_model.plot(ax)
        plt.draw()

    fig.canvas.mpl_connect('button_press_event', onclick)
    plt.show()
class AcquistionFunctionsTypes(enum.Enum):
    """Acquisition functions understood by ``BayesianOptimiser.select_query``.

    The (misspelled) class name is kept as-is because other code refers to it.
    """

    THOMPSON = 0  # maximise a single posterior sample of f
    UCB = 1       # upper confidence bound
    POI = 2       # probability of improvement
class BayesianOptimiser(object):
    """GP-based Bayesian optimiser over the 1-D interval ``run_settings["bounds"]``.

    It starts with both interval endpoints "observed" at 0. It keeps a GP regression
    model with an RBF kernel and a fixed Gaussian likelihood, and proposes queries by
    maximising an acquisition function over a regular grid.
    """

    def __init__(self, rng):
        self.seen_points = np.array([[run_settings["bounds"][0]], [run_settings["bounds"][1]]], dtype=float)
        self.seen_ys = np.array([[0.], [0.]], dtype=float)
        self._rng = rng
        self.model = gpf.models.GPR(self.seen_points, self.seen_ys,
                                    kern=gpf.kernels.RBF(1, lengthscales=run_settings['gp_lengthscale']))
        # The Gaussian likelihood is parameterised by a *variance*, whereas the setting
        # is a standard deviation, so it must be squared (0.05 -> 0.0025).
        self.model.likelihood.variance = run_settings['measurement_noise_std'] ** 2
        self.model.likelihood.set_trainable(False)

    def add_sample(self, X, Y):
        """Append observation(s) ``Y`` at location(s) ``X`` and re-condition the GP."""
        self.seen_points = np.concatenate([self.seen_points, np.atleast_2d(X)], axis=0)
        self.seen_ys = np.concatenate([self.seen_ys, np.atleast_2d(Y)], axis=0)
        self.model.X.assign(self.seen_points)
        self.model.Y.assign(self.seen_ys)

    def select_query(self, acq_type, *, kappa=1.):
        """Pick the next point to query by maximising an acquisition over a grid.

        :param acq_type: an ``AcquistionFunctionsTypes`` member.
        :param kappa: exploration weight for UCB (mean + kappa * std); ignored otherwise.
        :return: ``(query, x_grid, acq)``: the chosen point as a (1, 1) array, the 250-point
            evaluation grid as a column, and the acquisition values on that grid.
        :raises NotImplementedError: for an unsupported ``acq_type``.
        """
        # Simple 1-D problem: evaluate on a dense grid and take the best grid point.
        x_grid = np.linspace(*run_settings["bounds"], 250)[:, np.newaxis]
        f_m, f_var = self.model.predict_f(x_grid)
        # Predictive variances can come out zero or marginally negative numerically;
        # floor them so the standard deviation is finite and strictly positive.
        f_std = np.sqrt(np.maximum(f_var, 1e-12))
        if acq_type is AcquistionFunctionsTypes.POI:
            current_max = np.max(f_m)
            # P(f > current_max); sf is the more accurate form of 1 - cdf in the tail.
            acq = stats.norm.sf(current_max, loc=f_m, scale=f_std)
        elif acq_type is AcquistionFunctionsTypes.UCB:
            # UCB trades off mean against the posterior *standard deviation*.
            acq = f_m + kappa * f_std
        elif acq_type is AcquistionFunctionsTypes.THOMPSON:
            acq = self.model.predict_f_samples(x_grid, 1)[0]
        else:
            raise NotImplementedError
        loc = random_argmax(acq[:, 0], self._rng)[0]
        return x_grid[loc:loc + 1, 0:1], x_grid, acq
def random_argmax(array, rng):
    """
    Like ordinary numpy argmax except picks one of the options if multiple rather than the first one.

    NaN entries are ignored when looking for the maximum.

    :param array: the array to find the max of
    :type array: np.ndarray
    :param rng: random number generator used to break ties
    :type rng: np.random.RandomState
    :return: 1-D array with the respective indices for each axis (length ``array.ndim``)
    :raises ValueError: if ``array`` is empty or contains only NaNs
    """
    array = np.asarray(array)
    if np.issubdtype(array.dtype, np.inexact):
        is_nan = np.isnan(array)
    else:
        is_nan = np.zeros(array.shape, dtype=bool)
    # Previously a single NaN made the max NaN, so no entry compared equal to it and
    # rng.randint(0) failed with an unhelpful error. `.all()` is also True when empty.
    if is_nan.all():
        raise ValueError("random_argmax needs at least one non-NaN value")
    max_value = array[~is_nan].max()
    max_locations = np.argwhere(array == max_value)  # one row of indices per tie
    return max_locations[rng.randint(len(max_locations))]
def run_bayesian_optimisation():
    """Interactive demo of Bayesian optimisation against a hidden GP oracle.

    Each press of the "Next" button advances the loop, alternating between two phases:
      * propose: the optimiser picks the next point with Thompson sampling. The
        acquisition values are drawn on the lower axes and the pick is marked with a
        dotted line.
      * evaluate: the oracle is queried at that point and the noisy observation is fed
        to the optimiser. The upper axes are redrawn with the new point highlighted.
    """
    rng = np.random.RandomState(100)
    oracle = OracleGPModel(rng)
    fig, (oracle_ax, acq_ax) = plt.subplots(2, 1, sharex=True)
    oracle.plot(oracle_ax)
    bayes_opt = BayesianOptimiser(rng)
    marker_colour = '#ed3956'
    presses = 0
    pending_query = None

    def propose():
        # Choose the next query point and show the acquisition that led to it.
        nonlocal pending_query
        query, x_for_acq, acq = bayes_opt.select_query(AcquistionFunctionsTypes.THOMPSON)
        acq_ax.clear()
        acq_ax.plot(x_for_acq, acq, color='#9fca45')
        y_span = (acq.min() - 0.1, acq.max() + 0.1)
        acq_ax.plot([query[0, 0]] * 2, y_span, ':', color=marker_colour, lw=3)
        pending_query = query

    def evaluate():
        # Query the oracle at the pending point, learn from it, and redraw the oracle.
        seen_y = oracle.query(pending_query)
        bayes_opt.add_sample(pending_query, seen_y)
        oracle_ax.clear()
        oracle.plot(oracle_ax)
        y_span = oracle_ax.get_ylim()
        x_marker = pending_query[0, 0]
        oracle_ax.plot([x_marker] * 2, y_span, ':', color=marker_colour, lw=3)
        oracle_ax.plot(pending_query, seen_y, 'x', ms=10, mew=5, color=marker_colour)

    def button_callback(_):
        nonlocal presses
        if presses % 2 == 0:
            propose()
        else:
            evaluate()
        presses += 1
        fig.canvas.draw()

    axnext = plt.axes([0.81, 0.05, 0.1, 0.075])
    bnext = Button(axnext, 'Next')  # keep a reference so the widget stays alive
    bnext.on_clicked(button_callback)
    plt.show()
def run_opts(demo: str = "human"):
    """Command-line entry point: launch one of the interactive demos.

    :param demo: ``"human"`` to choose query points yourself by clicking on the plot,
        or ``"bo"`` to let the Bayesian optimiser choose them (stepped with "Next").
    :raises NotImplementedError: for any other value of ``demo``.
    """
    if demo == "human":
        run_querier_human_in_loop()
    elif demo == "bo":
        run_bayesian_optimisation()
    else:
        raise NotImplementedError(
            "Unknown demo {!r}; expected 'human' or 'bo'.".format(demo))
if __name__ == '__main__':
    # Expose run_opts as a command-line interface via python-fire,
    # e.g. `python <script> --demo=bo`.
    fire.Fire(component=run_opts)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment