Created
March 4, 2019 23:40
-
-
Save rcurtin/a1deacc1debefbcd90b9d280364469a0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
From 3fbbbc05e9fc2825b53bc649c9f1593f8c9f2d84 Mon Sep 17 00:00:00 2001 | |
From: Ryan Curtin <ryan@ratml.org> | |
Date: Mon, 4 Mar 2019 18:38:32 -0500 | |
Subject: [PATCH] Fixes for SPSA and its documentation. | |
--- | |
HISTORY.md | 4 + | |
doc/function_types.md | 1 + | |
doc/optimizers.md | 14 ++-- | |
include/ensmallen_bits/spsa/spsa.hpp | 25 ++---- | |
include/ensmallen_bits/spsa/spsa_impl.hpp | 94 +++++++++-------------- | |
tests/spsa_test.cpp | 63 +++++++-------- | |
6 files changed, 83 insertions(+), 118 deletions(-) | |
diff --git a/HISTORY.md b/HISTORY.md | |
index 5fbc3a2..b0c0071 100644 | |
--- a/HISTORY.md | |
+++ b/HISTORY.md | |
@@ -1,3 +1,7 @@ | |
+### ensmallen 1.14.1 | |
+###### ????-??-?? | |
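+ * Fixes for SPSA: it now optimizes arbitrary functions, and the `batchSize` constructor parameter has been removed, so positional arguments after `alpha` shift by one (#87). | |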
+ | |
### ensmallen 1.14.0 | |
###### 2019-02-20 | |
* Add DE optimizer (#77). | |
diff --git a/doc/function_types.md b/doc/function_types.md | |
index 5e75a00..bef593d 100644 | |
--- a/doc/function_types.md | |
+++ b/doc/function_types.md | |
@@ -24,6 +24,7 @@ The following optimizers can be used to optimize an arbitrary function: | |
- [Simulated Annealing](#simulated-annealing-sa) | |
- [CNE](#cne) | |
- [DE](#de) | |
+ - [SPSA](#simultaneous-perturbation-stochastic-approximation-spsa) | |
Each of these optimizers has an `Optimize()` function that is called as | |
`Optimize(f, x)` where `f` is the function to be optimized (which implements | |
diff --git a/doc/optimizers.md b/doc/optimizers.md | |
index 9193ee0..c87c133 100644 | |
--- a/doc/optimizers.md | |
+++ b/doc/optimizers.md | |
@@ -1510,30 +1510,28 @@ optimizer.Optimize(f, coordinates); | |
## Simultaneous Perturbation Stochastic Approximation (SPSA) | |
-*An optimizer for [differentiable separable functions](#differentiable-separable-functions).* | |
+*An optimizer for [arbitrary functions](#arbitrary-functions).* | |
The SPSA algorithm approximates the gradient of the function by finite | |
differences along stochastic directions. | |
-### Constructors | |
+#### Constructors | |
- * `SPSA(`_`alpha, batchSize, gamma, stepSize, evaluationStepSize, maxIterations, tolerance, shuffle`_`)` | |
+ * `SPSA(`_`alpha, gamma, stepSize, evaluationStepSize, maxIterations, tolerance`_`)` | |
- #### Attributes | |
+#### Attributes | |
| **type** | **name** | **description** | **default** | | |
|----------|----------|-----------------|-------------| | |
| `double` | **`alpha`** | Scaling exponent for the step size. | `0.602` | | |
-| `size_t` | **`batchSize`** | Batch size to use for each step. | `32` | | |
| `double` | **`gamma`** | Scaling exponent for evaluation step size. | `0.101` | | |
| `double` | **`stepSize`** | Scaling parameter for step size (named as 'a' in the paper). | `0.16` | | |
| `double` | **`evaluationStepSize`** | Scaling parameter for evaluation step size (named as 'c' in the paper). | `0.3` | | |
| `size_t` | **`maxIterations`** | Maximum number of iterations allowed (0 means no limit). | `100000` | | |
| `double` | **`tolerance`** | Maximum absolute tolerance to terminate algorithm. | `1e-5` | | |
-| `bool` | **`shuffle`** | If true, the function order is shuffled; otherwise, each function is visited in linear order. | `true` | | |
Attributes of the optimizer may also be changed via the member methods | |
-`Alpha()`, `BatchSize()`, `Gamma()`, `StepSize()`, `EvaluationStepSize()`, and `MaxIterations()`. | |
+`Alpha()`, `Gamma()`, `StepSize()`, `EvaluationStepSize()`, and `MaxIterations()`. | |
#### Examples: | |
@@ -1541,7 +1539,7 @@ Attributes of the optimizer may also be changed via the member methods | |
SphereFunction f(2); | |
arma::mat coordinates = f.GetInitialPoint(); | |
-SPSA optimizer(0.1, 2, 0.102, 0.16, 0.3, 100000, 0); | |
+SPSA optimizer(0.1, 0.102, 0.16, 0.3, 100000, 1e-5); | |
optimizer.Optimize(f, coordinates); | |
``` | |
diff --git a/include/ensmallen_bits/spsa/spsa.hpp b/include/ensmallen_bits/spsa/spsa.hpp | |
index ce1fa5e..32d1053 100644 | |
--- a/include/ensmallen_bits/spsa/spsa.hpp | |
+++ b/include/ensmallen_bits/spsa/spsa.hpp | |
@@ -3,8 +3,7 @@ | |
* @author N Rajiv Vaidyanathan | |
* @author Marcus Edel | |
* | |
- * SPSA (Simultaneous perturbation stochastic approximation) method for | |
- * faster convergence. | |
+ * SPSA (Simultaneous perturbation stochastic approximation) method. | |
* | |
* ensmallen is free software; you may redistribute it and/or modify it under | |
* the terms of the 3-clause BSD license. You should have received a copy of | |
@@ -35,7 +34,7 @@ namespace ens { | |
* } | |
* @endcode | |
* | |
- * SPSA can optimize differentiable separable functions. For more details, | |
+ * SPSA can optimize arbitrary functions. For more details, | |
* see the documentation on function types included with this distribution or on | |
* the ensmallen website. | |
*/ | |
@@ -45,16 +44,13 @@ class SPSA | |
/** | |
* Construct the SPSA optimizer with the given function and parameters. The | |
* defaults here are not necessarily good for the given problem, so it is | |
- * suggested that the values used be tailored to the task at hand. The | |
- * maximum number of iterations refers to the maximum number of points that | |
- * are processed (i.e., one iteration equals one point; one iteration does not | |
- * equal one pass over the dataset). | |
+ * suggested that the values used be tailored to the task at hand. | |
* | |
* @param alpha Scaling exponent for the step size. | |
- * @param batchSize Batch size to use for each step. | |
* @param gamma Scaling exponent for evaluation step size. | |
* @param stepSize Scaling parameter for step size (named as 'a' in the paper). | |
- * @param evaluationStepSize Scaling parameter for evaluation step size (named as 'c' in the paper). | |
+ * @param evaluationStepSize Scaling parameter for evaluation step size (named | |
+ * as 'c' in the paper). | |
* @param maxIterations Maximum number of iterations allowed (0 means no | |
* limit). | |
* @param tolerance Maximum absolute tolerance to terminate algorithm. | |
@@ -62,7 +58,6 @@ class SPSA | |
* function is visited in linear order. | |
*/ | |
SPSA(const double alpha = 0.602, | |
- const size_t batchSize = 32, | |
const double gamma = 0.101, | |
const double stepSize = 0.16, | |
const double evaluationStepSize = 0.3, | |
@@ -78,11 +73,6 @@ class SPSA | |
//! Modify the scaling exponent for the step size. | |
double& Alpha() { return alpha; } | |
- //! Get the batch size. | |
- size_t BatchSize() const { return batchSize; } | |
- //! Modify the batch size. | |
- size_t& BatchSize() { return batchSize; } | |
- | |
//! Get the scaling exponent for evaluation step size. | |
double Gamma() const { return gamma; } | |
//! Modify the scaling exponent for evaluation step size. | |
@@ -107,9 +97,6 @@ class SPSA | |
//! Scaling exponent for the step size. | |
double alpha; | |
- //! The batch size for processing. | |
- size_t batchSize; | |
- | |
//! Scaling exponent for evaluation step size. | |
double gamma; | |
@@ -120,7 +107,7 @@ class SPSA | |
double evaluationStepSize; | |
//! Control the amount of gradient update. | |
- double Ak; | |
+ double ak; | |
//! The maximum number of allowed iterations. | |
size_t maxIterations; | |
diff --git a/include/ensmallen_bits/spsa/spsa_impl.hpp b/include/ensmallen_bits/spsa/spsa_impl.hpp | |
index 92f4c94..df92326 100644 | |
--- a/include/ensmallen_bits/spsa/spsa_impl.hpp | |
+++ b/include/ensmallen_bits/spsa/spsa_impl.hpp | |
@@ -22,7 +22,6 @@ | |
namespace ens { | |
inline SPSA::SPSA(const double alpha, | |
- const size_t batchSize, | |
const double gamma, | |
const double stepSize, | |
const double evaluationStepSize, | |
@@ -30,23 +29,22 @@ inline SPSA::SPSA(const double alpha, | |
const double tolerance, | |
const bool shuffle) : | |
alpha(alpha), | |
- batchSize(batchSize), | |
gamma(gamma), | |
stepSize(stepSize), | |
evaluationStepSize(evaluationStepSize), | |
- Ak(0.001 * maxIterations), | |
+ ak(0.001 * maxIterations), | |
maxIterations(maxIterations), | |
tolerance(tolerance), | |
shuffle(shuffle) | |
{ /* Nothing to do. */ } | |
-template<typename DecomposableFunctionType> | |
+template<typename ArbitraryFunctionType> | |
inline double SPSA::Optimize( | |
- DecomposableFunctionType& function, arma::mat& iterate) | |
+ ArbitraryFunctionType& function, arma::mat& iterate) | |
{ | |
// Make sure that we have the methods that we need. | |
- traits::CheckNonDifferentiableDecomposableFunctionTypeAPI< | |
- DecomposableFunctionType>(); | |
+ // TODO: CheckArbitraryFunctionTypeAPI isn't implemented yet. | |
+// traits::CheckArbitraryFunctionTypeAPI<ArbitraryFunctionType>(); | |
arma::mat gradient(iterate.n_rows, iterate.n_cols); | |
arma::mat spVector(iterate.n_rows, iterate.n_cols); | |
@@ -55,69 +53,53 @@ inline double SPSA::Optimize( | |
double overallObjective = 0; | |
double lastObjective = DBL_MAX; | |
- const size_t actualMaxIterations = (maxIterations == 0) ? | |
- std::numeric_limits<size_t>::max() : maxIterations; | |
- for (size_t k = 0; k < actualMaxIterations; /* incrementing done manually */) | |
+ for (size_t k = 0; k < maxIterations; ++k) | |
{ | |
- // Is this iteration the start of a sequence? | |
- if (k > 0) | |
+ // Output current objective function. | |
+ Info << "SPSA: iteration " << k << ", objective " << overallObjective | |
+ << "." << std::endl; | |
+ | |
+ if (std::isnan(overallObjective) || std::isinf(overallObjective)) | |
+ { | |
+ Warn << "SPSA: converged to " << overallObjective << "; terminating" | |
+ << " with failure. Try a smaller step size?" << std::endl; | |
+ return overallObjective; | |
+ } | |
+ | |
+ if (std::abs(lastObjective - overallObjective) < tolerance) | |
{ | |
- // Output current objective function. | |
- Info << "SPSA: iteration " << k << ", objective " << overallObjective | |
- << "." << std::endl; | |
- | |
- if (std::isnan(overallObjective) || std::isinf(overallObjective)) | |
- { | |
- Warn << "SPSA: converged to " << overallObjective << "; terminating" | |
- << " with failure. Try a smaller step size?" << std::endl; | |
- return overallObjective; | |
- } | |
- | |
- if (std::abs(lastObjective - overallObjective) < tolerance) | |
- { | |
- Info << "SPSA: minimized within tolerance " << tolerance << "; " | |
- << "terminating optimization." << std::endl; | |
- return overallObjective; | |
- } | |
- | |
- // Reset the counter variables. | |
- lastObjective = overallObjective; | |
- | |
- if (shuffle) // Determine order of visitation. | |
- function.Shuffle(); | |
+ Warn << "SPSA: minimized within tolerance " << tolerance << "; " | |
+ << "terminating optimization." << std::endl; | |
+ return overallObjective; | |
} | |
+ // Reset the counter variables. | |
+ lastObjective = overallObjective; | |
+ | |
// Gain sequences. | |
- const double ak = stepSize / std::pow(k + 1 + Ak, alpha); | |
+ const double akLocal = stepSize / std::pow(k + 1 + ak, alpha); | |
const double ck = evaluationStepSize / std::pow(k + 1, gamma); | |
- gradient.zeros(); | |
- for (size_t b = 0; b < batchSize; b++) | |
- { | |
- // Stochastic directions. | |
- spVector = arma::conv_to<arma::mat>::from( | |
- arma::randi(iterate.n_rows, iterate.n_cols, | |
- arma::distr_param(0, 1))) * 2 - 1; | |
- | |
- iterate += ck * spVector; | |
- const double fPlus = function.Evaluate(iterate, 0, iterate.n_elem); | |
+ // Choose stochastic directions. | |
+ spVector = arma::conv_to<arma::mat>::from( | |
+ arma::randi(iterate.n_rows, iterate.n_cols, | |
+ arma::distr_param(0, 1))) * 2 - 1; | |
- iterate -= 2 * ck * spVector; | |
- const double fMinus = function.Evaluate(iterate, 0, iterate.n_elem); | |
- iterate += ck * spVector; | |
+ iterate += ck * spVector; | |
+ const double fPlus = function.Evaluate(iterate); | |
- gradient += (fPlus - fMinus) * (1 / (2 * ck * spVector)); | |
- } | |
+ iterate -= 2 * ck * spVector; | |
+ const double fMinus = function.Evaluate(iterate); | |
+ iterate += ck * spVector; | |
- gradient /= (double) batchSize; | |
- iterate -= ak * gradient; | |
+ gradient = (fPlus - fMinus) * (1 / (2 * ck * spVector)); | |
+ iterate -= akLocal * gradient; | |
- overallObjective = function.Evaluate(iterate, 0, iterate.n_elem); | |
- k += batchSize; | |
+ overallObjective = function.Evaluate(iterate); | |
} | |
// Calculate final objective. | |
- return function.Evaluate(iterate, 0, iterate.n_elem); | |
+ return function.Evaluate(iterate); | |
} | |
} // namespace ens | |
diff --git a/tests/spsa_test.cpp b/tests/spsa_test.cpp | |
index 8d946df..28cdbcf 100644 | |
--- a/tests/spsa_test.cpp | |
+++ b/tests/spsa_test.cpp | |
@@ -19,31 +19,13 @@ using namespace arma; | |
using namespace ens; | |
using namespace ens::test; | |
-/** | |
- * Test the SPSA optimizer on the SGDTest function. | |
- */ | |
-TEST_CASE("SPSASimpleSGDTestFunction","[SPSATest]") | |
-{ | |
- SGDTestFunction f; | |
- SPSA optimizer(0.1, 1, 0.102, 0.16, 0.3, 100000, 0); | |
- | |
- arma::mat coordinates = f.GetInitialPoint(); | |
- coordinates.ones(); | |
- double result = optimizer.Optimize(f, coordinates); | |
- | |
- REQUIRE(result == Approx(-1.0).epsilon(0.0005)); | |
- REQUIRE(coordinates[0] == Approx(0.0).margin(1e-3)); | |
- REQUIRE(coordinates[1] == Approx(0.0).margin(1e-7)); | |
- REQUIRE(coordinates[2] == Approx(0.0).margin(1e-7)); | |
-} | |
- | |
/** | |
* Test the SPSA optimizer on the Sphere function. | |
*/ | |
TEST_CASE("SPSASphereFunctionTest", "[SPSATest]") | |
{ | |
SphereFunction f(2); | |
- SPSA optimizer(0.1, 2, 0.102, 0.16, 0.3, 100000, 0); | |
+ SPSA optimizer(0.1, 0.102, 0.16, 0.3, 100000, 0); | |
arma::mat coordinates = f.GetInitialPoint(); | |
optimizer.Optimize(f, coordinates); | |
@@ -58,7 +40,7 @@ TEST_CASE("SPSASphereFunctionTest", "[SPSATest]") | |
TEST_CASE("SPSAMatyasFunctionTest", "[SPSATest]") | |
{ | |
MatyasFunction f; | |
- SPSA optimizer(0.1, 1, 0.102, 0.16, 0.3, 100000, 0); | |
+ SPSA optimizer(0.1, 0.102, 0.16, 0.3, 100000, 0); | |
arma::mat coordinates = f.GetInitialPoint(); | |
optimizer.Optimize(f, coordinates); | |
@@ -75,22 +57,33 @@ TEST_CASE("SPSAMatyasFunctionTest", "[SPSATest]") | |
*/ | |
TEST_CASE("SPSALogisticRegressionTest", "[SPSATest]") | |
{ | |
+ srand(std::time(NULL)); | |
+ arma::arma_rng::set_seed(std::time(NULL)); | |
arma::mat data, testData, shuffledData; | |
+ bool success = false; | |
arma::Row<size_t> responses, testResponses, shuffledResponses; | |
- LogisticRegressionTestData(data, testData, shuffledData, | |
- responses, testResponses, shuffledResponses); | |
- LogisticRegression<> lr(shuffledData, shuffledResponses, 0.5); | |
- | |
- SPSA optimizer(0.5, 1, 0.102, 0.16, 0.3, 100000, 1e-7); | |
- arma::mat coordinates = lr.GetInitialPoint(); | |
- optimizer.Optimize(lr, coordinates); | |
- | |
- // Ensure that the error is close to zero. | |
- const double acc = lr.ComputeAccuracy(data, responses, coordinates); | |
- REQUIRE(acc == Approx(100.0).epsilon(0.003)); // 0.3% error tolerance. | |
- | |
- const double testAcc = lr.ComputeAccuracy(testData, testResponses, | |
- coordinates); | |
- REQUIRE(testAcc == Approx(100.0).epsilon(0.006)); // 0.6% error tolerance. | |
+ for (size_t trial = 0; trial < 3; ++trial) | |
+ { | |
+ LogisticRegressionTestData(data, testData, shuffledData, | |
+ responses, testResponses, shuffledResponses); | |
+ LogisticRegression<> lr(shuffledData, shuffledResponses, 0.5); | |
+ | |
+ SPSA optimizer(0.5, 0.102, 0.002, 0.3, 1000, 1e-4); | |
+ arma::mat coordinates = lr.GetInitialPoint(); | |
+ optimizer.Optimize(lr, coordinates); | |
+ | |
+ // Ensure that the error is close to zero. | |
+ const double acc = lr.ComputeAccuracy(data, responses, coordinates); | |
+ const double testAcc = lr.ComputeAccuracy(testData, testResponses, | |
+ coordinates); | |
+ if (acc == Approx(100.0).epsilon(0.003) && | |
+ testAcc == Approx(100.0).epsilon(0.006)) | |
+ success = true; | |
+ | |
+ if (success) | |
+ break; | |
+ } | |
+ | |
+ REQUIRE(success == true); | |
} | |
-- | |
2.20.1 | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment