Skip to content

Instantly share code, notes, and snippets.

@ajasja
Created May 7, 2012 10:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ajasja/2627215 to your computer and use it in GitHub Desktop.
Save ajasja/2627215 to your computer and use it in GitHub Desktop.
NelderMeadMinimize`Dump`CompiledNelderMead
(* Produces compiled code for the Nelder-Mead algorithm with the objective function inlined. *)
(* The objective function takes the form F[parametersToOptimize.., constantParameters]. *)
(*
  Arguments:
    objectiveFunction - Function or CompiledFunction to minimize. It is inlined through the
                        helper `apply` (defined elsewhere in the package), using the symbol
                        list vars~Join~const as its formal inputs.
    vars              - symbols of the parameters being optimized.
    const             - symbols of the parameters held fixed during the optimization.
    opts              - "HistoryLength" (Automatic means 10 Length[vars]), "ReflectRatio",
                        "ExpandRatio", "ContractRatio", "ShrinkRatio" and CompilationTarget
                        (CompilationTarget is only passed on in Mathematica 8 and above).
  Returns a CompiledFunction with arguments {pts, cst, tol, maxit}:
    pts   - initial simplex, one point per row. It presumably needs Length[vars] + 1 rows
            (TODO confirm); the row count is not checked here.
    cst   - numeric values of the constant parameters. They are appended to every point
            before each evaluation of the objective.
    tol   - convergence tolerance applied to the recent history of best function values.
    maxit - maximum number of iterations.
  The compiled function returns the best point found, followed by the operation counts
  {evaluations, reflections, expansions, contractions, shrinkages}.
  Each distinct argument sequence is compiled once and then memoized via the ":= ... =" idiom.
*)
NelderMeadMinimize`Dump`CompiledNelderMead[
  objectiveFunction_Function | objectiveFunction_CompiledFunction, vars : {__Symbol}, const : {__Symbol},
  opts : OptionsPattern[NelderMeadMinimize`Dump`CompiledNelderMead]
 ] :=
 (* The memoized left-hand side must list the same arguments as the pattern above.
    Note the comma after const: without it, "const opts" would parse as Times[const, opts]. *)
 NelderMeadMinimize`Dump`CompiledNelderMead[
  objectiveFunction, vars, const,
  opts
 ] =
 With[{
   (* Inlined option values *)
   historyLength = If[# === Automatic, 10 Length[vars], #] & @ OptionValue["HistoryLength"],
   reflectRatio = OptionValue["ReflectRatio"], expandRatio = OptionValue["ExpandRatio"],
   contractRatio = OptionValue["ContractRatio"], shrinkRatio = OptionValue["ShrinkRatio"],
   (* Other inlined values *)
   origin = ConstantArray[0., Length[vars]],
   infinity = $MaxMachineNumber,
   epsilon = $MachineEpsilon,
   (* Inlined functions. f is always called on a single vector: a point's coordinates
      followed by the constant values (see the f[...~Join~cst] calls below).
      `apply` is defined elsewhere. *)
   f = apply[objectiveFunction, Evaluate[vars~Join~const]],
   (* Defined elsewhere; presumably the summed absolute differences between successive
      history entries (TODO confirm). *)
   diffs = cumulativeAbsoluteDifferences,
   (* Options to be passed to Compile *)
   compileopts = Sequence @@ If[$VersionNumber >= 8, {
       (* Mathematica 8 and above offer improved behaviour using these options *)
       RuntimeOptions -> {"Speed", "CompareWithTolerance" -> True, "EvaluateSymbolically" -> False},
       CompilationTarget -> OptionValue[CompilationTarget],
       CompilationOptions -> {"ExpressionOptimization" -> True, "InlineCompiledFunctions" -> Automatic, "InlineExternalDefinitions" -> True}
      }, {
       (* Ordering is an external call in Mathematica 7 and so needs type information *)
       {{_Ordering, _Integer, 1}}
      }
     ]
  },
  Compile[{{pts, _Real, 2}, {cst, _Real, 1}, {tol, _Real, 0}, {maxit, _Integer, 0}},
   Block[{
     (* Housekeeping: rolling window of recent best values, and remaining iterations *)
     history = Table[infinity, {historyLength}], iteration = maxit,
     (* Basic quantities *)
     simplex = pts, vals = f[#~Join~cst] & /@ pts, ordering,
     (* Calculated points and function values *)
     centroid = origin,
     reflectedPoint = origin, reflectedValue = infinity,
     expandedPoint = origin, expandedValue = infinity,
     contractedPoint = origin, contractedValue = infinity,
     (* More readable indices into the (sorted) simplex array *)
     best = 1, worst = -1, rest = Rest@Range@Length[pts],
     (* Operation counts (for debugging purposes) *)
     evaluations = Length[pts],
     reflections = 0, expansions = 0, contractions = 0, shrinkages = 0
    },
    While[
     (* Order simplex points by function value, so that best = 1 and worst = -1 hold *)
     ordering = Ordering[vals];
     vals = vals[[ordering]]; simplex = simplex[[ordering]];
     (* Decrement and test iterator. Continue[] below returns here, so the simplex is
        re-sorted before every iteration. *)
     (iteration--) != 0,
     (* Check for convergence: the newest best value ends up last in the window *)
     history[[1]] = vals[[best]]; history = RotateLeft[history];
     (* NOTE(review): the relative term scales with diffs[history] itself, so this is
        effectively diffs[history] (1 - epsilon) <= tol, i.e. an absolute tolerance. *)
     If[diffs[history] <= tol + epsilon diffs[history],
      Break[]
     ];
     (* Find centroid of all points except the worst one *)
     centroid = Mean@Most[simplex];
     (* Reflect *)
     reflectedPoint = centroid + reflectRatio (centroid - simplex[[worst]]);
     reflectedValue = f[reflectedPoint~Join~cst]; ++evaluations;
     (* Accept the reflection if it lies between the best and second-worst values *)
     If[vals[[best]] <= reflectedValue < vals[[-2]],
      vals[[worst]] = reflectedValue; simplex[[worst]] = reflectedPoint;
      ++reflections; Continue[]
     ];
     (* Expand *)
     If[reflectedValue < vals[[best]],
      expandedPoint = centroid + expandRatio (reflectedPoint - centroid);
      expandedValue = f[expandedPoint~Join~cst]; ++evaluations;
      If[expandedValue < reflectedValue,
       vals[[worst]] = expandedValue; simplex[[worst]] = expandedPoint;
       ++expansions; Continue[],
       vals[[worst]] = reflectedValue; simplex[[worst]] = reflectedPoint;
       ++reflections; Continue[]
      ];
     ];
     (* Contract *)
     If[reflectedValue < vals[[worst]],
      (* Outside contraction *)
      contractedPoint = centroid + contractRatio (reflectedPoint - centroid);
      contractedValue = f[contractedPoint~Join~cst]; ++evaluations;
      If[contractedValue <= reflectedValue,
       vals[[worst]] = contractedValue; simplex[[worst]] = contractedPoint;
       ++contractions; Continue[]
      ];,
      (* Inside contraction *)
      contractedPoint = centroid - contractRatio (centroid - simplex[[worst]]);
      contractedValue = f[contractedPoint~Join~cst]; ++evaluations;
      If[contractedValue < vals[[worst]],
       vals[[worst]] = contractedValue; simplex[[worst]] = contractedPoint;
       ++contractions; Continue[]
      ];
     ];
     (* Shrink every non-best point towards the best one. Map over the rows explicitly:
        matrix - vector arithmetic would thread the best point over the rows (pairing its
        i-th coordinate with the i-th row) instead of subtracting it from each row. *)
     simplex[[rest]] = (simplex[[best]] + shrinkRatio (# - simplex[[best]])) & /@ simplex[[rest]];
     (* Append the constant parameters, exactly as in every other evaluation *)
     vals[[rest]] = f[#~Join~cst] & /@ simplex[[rest]];
     (* One objective evaluation per shrunk point *)
     evaluations += Length[rest];
     ++shrinkages;
    ];
    (* Returning the full state {vals, simplex, counts} would require a call out of the VM.
       The result is therefore one flat vector: the best point (the simplex is sorted at
       this point) followed by {evaluations, reflections, expansions, contractions, shrinkages}. *)
    First[simplex]~Join~{evaluations, reflections, expansions, contractions, shrinkages}
   ], compileopts
  ]
 ];
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment