Last active
October 25, 2019 11:46
-
-
Save jkbjh/fb8ce6e53878d0f71c549201d1184d83 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# cross-entropy optimization | |
import numpy as np | |
import tqdm | |
def ce_minimize(X_mean, X_std, eval_loss, iterations=100, sample_size=25,
                keep_samples=10, rng=None, progress=True):
    """Minimize ``eval_loss`` with the cross-entropy method.

    Each round draws ``sample_size`` candidates from the axis-aligned Gaussian
    N(X_mean, diag(X_std**2)), evaluates ``eval_loss`` on every candidate,
    keeps the ``keep_samples`` candidates with the lowest loss (the elite set)
    and refits the mean and per-dimension standard deviation to that set.

    Args:
        X_mean: initial mean; array-like of shape (d,).
        X_std: initial per-dimension standard deviation; array-like that
            broadcasts to the shape of ``X_mean`` (a scalar is accepted).
        eval_loss: callable mapping one candidate of shape (d,) to a scalar.
        iterations: number of sample/refit rounds.
        sample_size: candidates drawn per round.
        keep_samples: elite candidates kept per round; must satisfy
            ``0 < keep_samples < sample_size``.
        rng: object providing ``multivariate_normal`` (for example a
            ``np.random.Generator`` or ``np.random.RandomState``). ``None``
            draws from the global ``np.random`` state.
        progress: if true, show a ``tqdm`` progress bar with the mean loss of
            each round.

    Returns:
        The final mean as a float ndarray of shape (d,).

    Raises:
        ValueError: if ``sample_size`` or ``keep_samples`` is out of range.
    """
    # Explicit exceptions instead of ``assert``: asserts vanish under
    # ``python -O`` and keep_samples == 0 would then produce NaN means.
    if not 0 < sample_size:
        raise ValueError("sample_size must be positive, got %r" % (sample_size,))
    if not 0 < keep_samples < sample_size:
        raise ValueError(
            "keep_samples must satisfy 0 < keep_samples < sample_size, "
            "got keep_samples=%r, sample_size=%r" % (keep_samples, sample_size))

    # Convert before any arithmetic so lists/tuples and a scalar std work.
    X_mean = np.asarray(X_mean, dtype=float)
    X_std = np.broadcast_to(np.asarray(X_std, dtype=float), X_mean.shape)
    sampler = np.random if rng is None else rng

    bar = tqdm.trange(iterations) if progress else None
    try:
        for _ in (bar if bar is not None else range(iterations)):
            # Diagonal covariance: each dimension is sampled independently.
            Xs = sampler.multivariate_normal(mean=X_mean,
                                             cov=np.diag(X_std ** 2),
                                             size=sample_size)
            losses = np.array([eval_loss(x) for x in Xs])
            # Refit the Gaussian to the lowest-loss candidates.
            elite = Xs[losses.argsort()[:keep_samples]]
            X_mean = elite.mean(axis=0)
            X_std = elite.std(axis=0)
            if bar is not None:
                bar.set_description("mean loss: %8.3g" % (np.mean(losses)))
    finally:
        # Release the progress bar even when eval_loss raises.
        if bar is not None:
            bar.close()
    return X_mean
def test():
    """Smoke-test ce_minimize on a 1-D quadratic whose minimum is at x = 5.

    ce_minimize samples from the global NumPy RNG, so the global RNG is seeded
    for the duration of the call to make the check reproducible. Its previous
    state is restored afterwards so callers' own seeding is not disturbed.
    """
    def eval_loss(x):
        # Unpack the single coordinate; raises if x does not have length 1.
        one_x, = x
        return (one_x - 5.) ** 2.

    saved_state = np.random.get_state()
    np.random.seed(0)
    try:
        X = ce_minimize(np.array([0.]), np.array([4.]), eval_loss)
    finally:
        np.random.set_state(saved_state)
    np.testing.assert_array_almost_equal(np.array([5.]), X)
    print(X)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment