Skip to content

Instantly share code, notes, and snippets.

@jkbjh
Last active October 25, 2019 11:46
Show Gist options
  • Save jkbjh/fb8ce6e53878d0f71c549201d1184d83 to your computer and use it in GitHub Desktop.
Save jkbjh/fb8ce6e53878d0f71c549201d1184d83 to your computer and use it in GitHub Desktop.
# cross-entropy optimization
import numpy as np
import tqdm
def ce_minimize(X_mean, X_std, eval_loss, iterations=100, sample_size=25, keep_samples=10):
assert 0 < sample_size
assert 0 < keep_samples and keep_samples < sample_size
# Alg settings:
trange = tqdm.trange(iterations)
for i in trange:
# sample parameter vectors
Xs = np.random.multivariate_normal(mean=X_mean,
cov=np.diag(np.array(X_std**2)),
size=sample_size)
losses = np.array([eval_loss(x) for x in Xs])
# keep the best samples
keep_inds = losses.argsort()[:keep_samples]
keep_Xs = Xs[keep_inds]
X_mean = keep_Xs.mean(axis=0)
X_std = keep_Xs.std(axis=0)
trange.set_description("mean loss: %8.3g" % (np.mean(losses)))
return X_mean
def test():
    """Smoke test: ce_minimize should find the minimiser x = 5 of (x - 5)**2.

    ce_minimize is stochastic, so the global NumPy RNG is seeded to make the
    run reproducible; the previous RNG state is restored afterwards so this
    test does not change the random stream seen by other code.
    """
    def eval_loss(x):
        one_x, = x
        loss = (one_x - 5.)**2.
        return loss

    saved_state = np.random.get_state()
    np.random.seed(0)
    try:
        X_mean = np.array([0.])
        X_std = np.array([4.])
        X = ce_minimize(X_mean, X_std, eval_loss)
    finally:
        np.random.set_state(saved_state)
    np.testing.assert_array_almost_equal(np.array([5.]), X)
    print(X)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment