Created
November 15, 2017 18:49
-
-
Save john-bradshaw/089a352dd85d2a75b5dc090f8d0d6fa2 to your computer and use it in GitHub Desktop.
Simple example of running Bayesian optimisation with Gaussian processes (GPs). Works with Python 3, TF 1.4, python-fire, and GPflow at git commit 4ff00cbbc83efff8cb537f16a7eb1c1e11de3a75
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import enum | |
import numpy as np | |
from scipy import stats | |
import fire | |
import matplotlib | |
matplotlib.use('TkAgg') | |
import matplotlib.pyplot as plt | |
from matplotlib.widgets import Button | |
import gpflow as gpf | |
# Shared configuration read by the oracle, the optimiser and the plotting code.
run_settings = {
    # [lower, upper] end points of the 1-D input interval; used for the initial
    # observations and as the range of every evaluation/plotting grid.
    "bounds": [0, 5],
    # Passed as `lengthscales` to the RBF kernels of both GP models.
    "gp_lengthscale":1.,
    # Scales the standard-normal noise the oracle adds to each sampled value.
    "measurement_noise_std": 0.05,
}
class OracleGPModel(object):
    """Ground-truth function for the demos, realised lazily as a GP sample.

    The oracle keeps a GPflow GPR model whose targets are the noise-free
    function values drawn so far (``seen_fs``); callers only ever receive the
    noise-corrupted values (``seen_ys``).
    """

    def __init__(self, rng: np.random.RandomState):
        """
        :param rng: random number generator used for the measurement noise
        :type rng: np.random.RandomState
        """
        lower_bound = run_settings["bounds"][0]
        upper_bound = run_settings["bounds"][1]
        # Start with one observation of value 0 at each end of the interval.
        self.seen_points = np.array([[lower_bound], [upper_bound]], dtype=float)
        self.seen_fs = np.array([[0.], [0.]], dtype=float)
        self.seen_ys = np.array([[0.], [0.]], dtype=float)
        self._rng = rng

        kernel = gpf.kernels.RBF(1, lengthscales=run_settings['gp_lengthscale'])
        self.model = gpf.models.GPR(self.seen_points, self.seen_ys, kern=kernel)
        # Small fixed likelihood variance; it is excluded from training.
        self.model.likelihood.variance = 1e-3
        self.model.likelihood.set_trainable(False)

    def query(self, point):
        """Sample the function at ``point`` and return a noisy measurement.

        The noise-free sample is stored in ``seen_fs`` and becomes part of the
        model's targets; the returned value (also stored in ``seen_ys``) has
        Gaussian noise scaled by ``run_settings['measurement_noise_std']``.

        :param point: location(s) to query; promoted with ``np.atleast_2d``
        :return: the noisy observation(s) at ``point``
        """
        print(point)
        point_2d = np.atleast_2d(point)
        self.seen_points = np.concatenate([self.seen_points, point_2d], axis=0)

        # One joint sample from the model at the queried location(s).
        latent_value = self.model.predict_f_samples(point_2d, 1)[0]
        self.seen_fs = np.concatenate([self.seen_fs, latent_value], axis=0)

        noise = self._rng.randn(*point_2d.shape) * run_settings['measurement_noise_std']
        noisy_value = latent_value + noise
        self.seen_ys = np.concatenate([self.seen_ys, noisy_value], axis=0)

        # Condition the oracle on the noise-free values, not the noisy ones.
        self.model.X.assign(self.seen_points)
        self.model.Y.assign(self.seen_fs)
        return noisy_value

    def plot(self, ax, plot_gp_parts=True):
        """Draw the noisy observations and, optionally, the oracle's GP.

        :param ax: matplotlib axes to draw on
        :param plot_gp_parts: when true, also draw the predictive mean and a
            band of two predictive standard deviations either side of it
        """
        x_range = np.linspace(*run_settings["bounds"], 100)[:, np.newaxis]
        if plot_gp_parts:
            mean, var = self.model.predict_f(x_range)
            ax.plot(x_range, mean, '--', color='#2e267c', lw=3)
            centre = mean[:, 0]
            half_width = 2 * np.sqrt(var[:, 0])
            ax.fill_between(x_range[:, 0], centre - half_width, centre + half_width,
                            color='#37c2df', alpha=0.2)
        ax.plot(self.seen_points, self.seen_ys, 'kx', ms=10, mew=5)
def run_querier_human_in_loop():
    """Interactive demo: each mouse click inside the plot queries the oracle.

    The x-coordinate of the click is sent to the oracle, after which the axes
    are cleared and the oracle's GP and observations are redrawn.
    """
    rng = np.random.RandomState(100)
    gp_model = OracleGPModel(rng)

    fig, ax = plt.subplots()
    gp_model.plot(ax)

    def onclick(event):
        # Clicks outside the plotting axes carry xdata/ydata of None, which
        # would break both the %f formatting below and the oracle query.
        if event.inaxes is not ax or event.xdata is None:
            return
        print('%s click: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' %
              ('double' if event.dblclick else 'single', event.button,
               event.x, event.y, event.xdata, event.ydata))
        gp_model.query(np.array([[event.xdata]]))
        ax.clear()
        gp_model.plot(ax)
        plt.draw()

    fig.canvas.mpl_connect('button_press_event', onclick)
    plt.show()
@enum.unique
class AcquistionFunctionsTypes(enum.Enum):
    """Acquisition functions understood by ``BayesianOptimiser.select_query``."""

    # Maximise a single posterior sample.
    THOMPSON = 0
    # Upper confidence bound.
    UCB = 1
    # Probability of improvement.
    POI = 2
class BayesianOptimiser(object):
    """Grid-based Bayesian optimiser backed by a GPflow GP regression model.

    The model is fitted to the noisy observations it is handed through
    ``add_sample``; ``select_query`` scores a fixed grid over
    ``run_settings["bounds"]`` with an acquisition function and returns the
    best-scoring grid point.
    """

    def __init__(self, rng):
        """
        :param rng: random number generator used to break ties between
            equally good grid points
        :type rng: np.random.RandomState
        """
        # Start with one observation of value 0 at each end of the interval.
        self.seen_points = np.array([[run_settings["bounds"][0]], [run_settings["bounds"][1]]], dtype=float)
        self.seen_ys = np.array([[0.], [0.]], dtype=float)
        self._rng = rng
        self.model = gpf.models.GPR(
            self.seen_points, self.seen_ys,
            kern=gpf.kernels.RBF(1, lengthscales=run_settings['gp_lengthscale']))
        # The likelihood parameter is a *variance*, whereas the setting is a
        # standard deviation, so it has to be squared.
        self.model.likelihood.variance = run_settings['measurement_noise_std'] ** 2
        self.model.likelihood.set_trainable(False)

    def add_sample(self, X, Y):
        """Append new observations and update the model's training data.

        :param X: queried location(s); promoted with ``np.atleast_2d``
        :param Y: observed value(s) at ``X``; promoted with ``np.atleast_2d``
        """
        self.seen_points = np.concatenate([self.seen_points, np.atleast_2d(X)], axis=0)
        self.seen_ys = np.concatenate([self.seen_ys, np.atleast_2d(Y)], axis=0)
        self.model.X.assign(self.seen_points)
        self.model.Y.assign(self.seen_ys)

    def select_query(self, acq_type, kappa=1.0, xi=0.0):
        """Pick the next point to evaluate by maximising an acquisition function.

        The problem is one-dimensional, so the acquisition function is simply
        evaluated on a grid of 250 points and the best grid point is returned
        (ties are broken at random).

        :param acq_type: which acquisition function to use
        :type acq_type: AcquistionFunctionsTypes
        :param kappa: UCB only; weight on the predictive standard deviation
        :param xi: POI only; margin added to the incumbent value
        :return: tuple ``(query, x_grid, acq)`` where ``query`` is the chosen
            grid point as a 1x1 array, ``x_grid`` is the 250x1 grid and
            ``acq`` holds the acquisition value at every grid point
        :raises NotImplementedError: if ``acq_type`` is not a known member
        """
        x_grid = np.linspace(*run_settings["bounds"], 250)[:, np.newaxis]

        if acq_type is AcquistionFunctionsTypes.THOMPSON:
            # A single joint posterior sample over the grid is the acquisition.
            acq = self.model.predict_f_samples(x_grid, 1)[0]
        elif acq_type is AcquistionFunctionsTypes.UCB:
            f_m, f_var = self.model.predict_f(x_grid)
            # Upper confidence bound: mean plus kappa standard deviations
            # (not kappa times the variance).
            acq = f_m + kappa * np.sqrt(f_var)
        elif acq_type is AcquistionFunctionsTypes.POI:
            f_m, f_var = self.model.predict_f(x_grid)
            # Incumbent is the largest posterior mean on the grid.
            # NOTE(review): with xi == 0 no grid point can score above 0.5 and
            # the maximiser coincides with the argmax of the mean, i.e. the
            # choice is purely exploitative; pass xi > 0 to explore.
            current_max = np.max(f_m)
            acq = stats.norm.sf(current_max + xi, loc=f_m, scale=np.sqrt(f_var))
        else:
            raise NotImplementedError

        loc = random_argmax(acq[:, 0], self._rng)[0]
        return x_grid[loc:loc+1, 0:1], x_grid, acq
def random_argmax(array, rng):
    """
    Argmax that breaks ties at random instead of always taking the first hit.

    :param array: the array to find the max of
    :type array: np.ndarray
    :param rng: random number generator used to choose between tied maxima
    :type rng: np.random.RandomState
    :return: 1-D array holding the index along each axis of the chosen maximum
    """
    is_max = np.max(array) == array
    # One row per maximal entry, one column per axis of `array`.
    candidates = np.transpose(np.nonzero(is_max))
    chosen_row = rng.randint(len(candidates))
    return candidates[chosen_row]
def run_bayesian_optimisation(acq_type=None):
    """Interactive demo that steps through Bayesian optimisation of the oracle.

    Each press of the "Next" button alternates between two phases:

    1. choose the next query with the acquisition function and draw the
       acquisition values (lower axes) with the chosen location marked;
    2. evaluate the oracle at that location, hand the observation to the
       optimiser and redraw the oracle (upper axes) with the new point.

    :param acq_type: acquisition function to use; ``None`` selects
        ``AcquistionFunctionsTypes.THOMPSON``
    :type acq_type: AcquistionFunctionsTypes
    """
    if acq_type is None:
        acq_type = AcquistionFunctionsTypes.THOMPSON

    rng = np.random.RandomState(100)
    oracle = OracleGPModel(rng)

    fig, ax = plt.subplots(2, 1, sharex=True)
    # Leave room at the bottom so the button does not cover the lower plot.
    fig.subplots_adjust(bottom=0.2)
    oracle.plot(ax[0])

    bayes_opt = BayesianOptimiser(rng)

    # State shared between successive button presses.
    awaiting_evaluation = False
    pending_query = None

    def button_callback(_):
        nonlocal awaiting_evaluation, pending_query
        if not awaiting_evaluation:
            # Phase 1: select the next query and show the acquisition values.
            query, x_for_acq, acq = bayes_opt.select_query(acq_type)
            ax[1].clear()
            ax[1].plot(x_for_acq, acq, color='#9fca45')
            ax_1_lims = (acq.min() - 0.1, acq.max() + 0.1)
            ax[1].plot([query[0, 0]] * 2, ax_1_lims, ':', color='#ed3956', lw=3)
            pending_query = query
        else:
            # Phase 2: evaluate the oracle there and update the optimiser.
            seen_y = oracle.query(pending_query)
            bayes_opt.add_sample(pending_query, seen_y)
            ax[0].clear()
            oracle.plot(ax[0])
            ax_0_lims = ax[0].get_ylim()
            ax[0].plot([pending_query[0, 0]] * 2, ax_0_lims, ':', color='#ed3956', lw=3)
            ax[0].plot(pending_query, seen_y, 'x', ms=10, mew=5, color='#ed3956')
        awaiting_evaluation = not awaiting_evaluation
        fig.canvas.draw()

    axnext = plt.axes([0.81, 0.05, 0.1, 0.075])
    # Keep a reference to the widget for as long as the window is open.
    bnext = Button(axnext, 'Next')
    bnext.on_clicked(button_callback)
    plt.show()
def run_opts(demo: str="human"):
    """Command-line entry point: launch the demo selected by ``demo``.

    :param demo: ``"human"`` for the click-to-query demo or ``"bo"`` for the
        Bayesian optimisation demo
    :raises NotImplementedError: for any other value of ``demo``
    """
    if demo == "human":
        run_querier_human_in_loop()
        return
    if demo == "bo":
        run_bayesian_optimisation()
        return
    raise NotImplementedError
# When executed as a script, expose `run_opts` (and its `demo` argument) as a
# command-line interface through python-fire.
if __name__ == '__main__':
    fire.Fire(run_opts)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment