Skip to content

Instantly share code, notes, and snippets.

@mwgamera
Last active May 5, 2020 09:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save mwgamera/5077246 to your computer and use it in GitHub Desktop.
Save mwgamera/5077246 to your computer and use it in GitHub Desktop.
A simple Lua implementation of the Nelder-Mead downhill simplex optimization algorithm.
#!/usr/bin/env luajit
local NMSolver = require "nmsolver"
local Vector = require "vector"
-- Himmelblau's test function:
--   f(x, y) = (x^2 + y - 11)^2 + (x + y^2 - 7)^2
-- X is an array-like vertex; only X[1] (x) and X[2] (y) are read.
local function hfun(X)
  local a = X[1]^2 + X[2] - 11
  local b = X[1] + X[2]^2 - 7
  return a^2 + b^2
end
-- Random initial simplex in (-4,4)^2: three random 2-D vertices.
-- The RNG seed comes from the first command-line argument (default 1021),
-- so a given seed always reproduces the same run.
local X0
do
  -- `arg` is only set when running as a standalone script; a non-numeric
  -- seed is rejected with a clear message instead of reaching randomseed
  -- as nil.
  local seed = 1021
  if arg ~= nil and arg[1] ~= nil then
    seed = assert(tonumber(arg[1]), "seed argument must be a number")
  end
  math.randomseed(seed)

  -- One vertex with both coordinates uniform in [-4, 4).
  local function randomvertex()
    return Vector:new {
      8 * math.random() - 4,
      8 * math.random() - 4
    }
  end

  X0 = {
    randomvertex(),
    randomvertex(),
    randomvertex()
  }
end
-- Solver minimizing Himmelblau's function from the random initial simplex.
local sim = NMSolver:new(hfun, X0)

-- Number of snapshots to write. The gnuplot script renders frames
-- 001..022, so keep the two counts in sync.
local NFRAMES = 22

--- Write the current simplex to "NNN.dat" as a closed polygon.
-- Each line holds "x y" for one vertex. The first vertex is repeated at
-- the end so that plotting "with lines" closes the outline.
-- @param n frame number used in the file name
-- @param solver object providing :vertices() (array of 2-D vertices)
local function dumpsv(n, solver)
  local v = solver:vertices()
  local f = assert(io.open(string.format("%03u.dat", n), "w"))
  for i = 1, #v do
    f:write(string.format("%f %f\n", v[i][1], v[i][2]))
  end
  f:write(string.format("%f %f\n", v[1][1], v[1][2]))
  f:close()
end

-- Dump before stepping, so frame 1 shows the initial simplex.
for i = 1, NFRAMES do
  dumpsv(i, sim)
  sim:next()
end
# vim: ft=gnuplot :
# Precompute Himmelblau's function over [-4:4]x[-4:4] into two temporary
# table files: the raw surface grid and its contour lines.
hfun(x,y) = (x**2+y-11)**2 + (x+y**2-7)**2

set xrange [-4:4]
set yrange [-4:4]
set samples 500
set isosamples 500, 500

# Surface values on the grid -> himg.dat (later drawn "with image").
set table 'himg.dat'
splot hfun(x, y)
unset table

# Contour lines only (surface suppressed), levels 1, 11, 21, ... up to 200
# -> hcont.dat.
unset surface
set contour base
set cntrparam levels incremental 1, 10, 200
set table 'hcont.dat'
splot hfun(x,y)
unset table
# Render one PNG per simplex snapshot: the surface image, the contour lines
# (line style 1) and the simplex outline from NNN.dat (line style 2).
set size square
unset key
set term pngcairo enhanced font 'Linux Biolinum, 11pt' size 500, 420
set palette model RGB
set palette rgbformulae 8,2,2
set style line 1 lw 0.83 lc rgb '#999999'   # contour lines
set style line 2 lw 1.75 lc rgb '#d91a80'   # simplex outline

do for [i = 1:22] {
    datfile = sprintf("%03u.dat", i)
    print datfile
    set out sprintf("%03u.png", i)
    plot 'himg.dat' with image, 'hcont.dat' w l ls 1, datfile w l ls 2
}

# Remove the temporary background tables.
!rm 'himg.dat' 'hcont.dat'
-- Nelder-Mead downhill simplex optimization algorithm.
--
-- The simplex lives in the array part of the instance: s[i] is a
-- {vertex, objective value} pair. Vertices only need to support `+`, `-`
-- and multiplication by a number (e.g. Vector instances, or plain numbers
-- for a one-dimensional problem).
local Simplex = {
  alpha = 1,   -- reflection coefficient
  gamma = 2,   -- expansion coefficient, relative to the reflection step
  rho = -0.5,  -- contraction coefficient; a negative value puts the trial
               -- point between the centroid and the worst vertex
  sigma = 0.5  -- shrink factor toward the best vertex
}

-- `unpack` became `table.unpack` in Lua 5.2 and is absent from 5.2+ builds
-- without compatibility options; resolve it once for every version.
local tunpack = table.unpack or unpack

--- Create a solver for the given objective function and initial simplex.
-- Evaluates `fun` once per initial vertex.
-- @param fun objective function mapping a vertex to a number (minimized)
-- @param sim array of initial vertices (normally n+1 for n dimensions)
-- @param o optional table to turn into the instance (may override the
--   coefficients above)
-- @return the new solver
function Simplex:new(fun, sim, o)
  self.__index = self
  local s = setmetatable(o or {}, self)
  s.obj = fun
  for i = 1, #sim do
    s[i] = { sim[i], fun(sim[i]) }
  end
  return s
end

--- Sort vertices in place by ascending objective value (best first).
-- @return self
function Simplex:sort()
  table.sort(self, function(a, b)
    return a[2] < b[2]
  end)
  return self
end

--- Perform one Nelder-Mead step.
-- The step either replaces the worst vertex or shrinks every vertex toward
-- the best one. The simplex is sorted at the start of the step; afterwards
-- the order is not guaranteed.
-- @return self
function Simplex:next()
  local n = #self
  assert(n >= 2, "simplex needs at least two vertices")
  self:sort()

  -- centroid of all vertices except the worst one
  local m = 1. / (n - 1)
  local c = m * self[1][1]
  for i = 2, n - 1 do
    c = c + m * self[i][1]
  end

  local worst = self[n]
  -- direction from the worst vertex through the centroid
  local d = c - worst[1]

  -- reflected point
  local rx = c + self.alpha * d
  local ry = self.obj(rx)

  -- expansion: go gamma times as far past the centroid as the reflected
  -- point, i.e. c + gamma * (rx - c)
  if ry < self[1][2] then
    local ex = c + (self.gamma * self.alpha) * d
    local ey = self.obj(ex)
    if ey < ry then
      self[n] = { ex, ey }
      return self
    end
  end

  -- reflection: accept if it beats the second-worst vertex
  if ry < self[n - 1][2] then
    self[n] = { rx, ry }
    return self
  end

  -- contraction
  local cx = c + self.rho * d
  local cy = self.obj(cx)
  if cy < worst[2] then
    self[n] = { cx, cy }
    return self
  end

  -- reduction: shrink every other vertex toward the best one
  local best = self[1][1]
  for i = 2, n do
    self[i][1] = best + self.sigma * (self[i][1] - best)
    self[i][2] = self.obj(self[i][1])
  end
  return self
end

--- Return a new array of the current vertices, without their objective
-- values, in the simplex's current storage order.
function Simplex:vertices()
  local v = {}
  for i = 1, #self do
    v[i] = self[i][1]
  end
  return v
end

--- Return the best vertex and its objective value.
-- Sorts the simplex as a side effect.
function Simplex:best()
  return tunpack(self:sort()[1])
end
return Simplex
-- Minimal n-dimensional vector type with arithmetic metamethods.
-- A vector is a plain array of numbers whose metatable is Vector.
local Vector = {}

--- Turn `x` (or a new empty table) into a Vector, in place.
-- @param x optional array of numbers; its metatable is replaced
-- @return x
function Vector:new(x)
  x = x or {}
  setmetatable(x, self)
  self.__index = self
  return x
end

--- Component-wise sum of two vectors of the same dimension.
function Vector.__add(x, y)
  assert(#x == #y, "vector dimension mismatch")
  local r = {}
  for i = 1, #x do
    r[i] = x[i] + y[i]
  end
  return Vector:new(r)
end

--- Product of a vector and a number, in either operand order.
function Vector.__mul(x, y)
  if type(x) == "number" then
    x, y = y, x
  end
  assert(type(y) == "number")
  local r = {}
  for i = 1, #x do
    r[i] = x[i] * y
  end
  return Vector:new(r)
end

--- Equal when both vectors have the same length and equal components.
function Vector.__eq(x, y)
  -- without this check a shorter x would compare equal to any extension of it
  if #x ~= #y then
    return false
  end
  for i = 1, #x do
    if x[i] ~= y[i] then
      return false
    end
  end
  return true
end

--- Component-wise difference of two vectors of the same dimension.
function Vector.__sub(x, y)
  assert(#x == #y, "vector dimension mismatch")
  local r = {}
  for i = 1, #x do
    r[i] = x[i] - y[i]
  end
  return Vector:new(r)
end

--- Negation, computed as multiplication by -1.
function Vector.__unm(x)
  return x * -1
end

--- Division of a vector by a number.
function Vector.__div(x, y)
  assert(type(y) == "number")
  return x * (1. / y)
end

--- Render as "{a,b,...}".
function Vector:__tostring()
  return "{".. table.concat(self, ",") .."}"
end
return Vector
@fsantini
Copy link

Hi! This code would be extremely useful for my project. Could you please tell me which license you are releasing it under? Is it MIT/X11-compatible, like Lua? Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment