Skip to content

Instantly share code, notes, and snippets.

@starwing
Created June 14, 2016 08:06
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save starwing/5e440a40ccb39986fac8260789d074dd to your computer and use it in GitHub Desktop.
Save starwing/5e440a40ccb39986fac8260789d074dd to your computer and use it in GitHub Desktop.
The Cassowary linear-constraint solving algorithm (as used for UI layout), implemented in pure Lua.
--- Create a class table that doubles as the metatable of its instances.
-- @param name   stored in the class table's `__name` field.
-- @param parent optional table installed as the class table's own metatable.
-- @return the class table, whose `__index` points back at itself.
local function meta(name, parent)
  -- `local` is essential: the original assigned to a global `t`.
  local t = { __name = name }
  t.__index = t
  return setmetatable(t, parent)
end
--- True when `a` and `b` differ by less than the solver tolerance (1e-6).
local function approx(a, b)
  local gap = a - b
  if gap < 0 then
    gap = -gap
  end
  return gap < 1e-6
end
--- True when `n` is within the solver tolerance of zero.
local function near_zero(n)
  -- approx() measures an absolute difference, so argument order is irrelevant.
  return approx(0.0, n)
end
--- Return `t[k]`, first storing `nv` (or a fresh table) there if it is unset.
-- Note that a stored `false` also counts as unset.
local function default(t, k, nv)
  local current = t[k]
  if current then
    return current
  end
  current = nv or {}
  t[k] = current
  return current
end
-- Core value types. The three class tables are declared up front so that the
-- methods defined in the block below can refer to one another.
local Variable, Expression, Constraint do
Variable, Expression, Constraint = meta "Variable", meta "Expression", meta "Constraint"
-- Predefined strengths; Constraint:set_strength also looks names up here.
Constraint.REQUIRED = 1e9
Constraint.STRONG = 1e6
Constraint.MEDIUM = 1e3
Constraint.WEAK = 1.0
-- Arithmetic metamethods. Every operator returns a NEW Expression and leaves
-- both operands untouched. When the left operand is a plain number (e.g.
-- `2 * x`), Lua calls the right operand's metamethod with that number as
-- `self`; Expression.new/add/multiply accept numbers, so this still works.
function Variable:__neg() return Expression.new(self, -1.0) end
-- Negated copy. The original mutated `self` in place and returned nil, which
-- made `a - b` silently ignore `b` whenever `b` was an Expression.
function Expression:__neg() return Expression.new(self, -1.0) end
function Variable:__add(other) return Expression.new(self) + other end
function Variable:__sub(other) return Expression.new(self) - other end
function Variable:__mul(other) return Expression.new(self) * other end
function Variable:__div(other) return Expression.new(self) / other end
function Expression:__add(other) return Expression.new(self):add(other) end
-- Subtract by adding with multiplier -1: no temporary negated operand needed.
function Expression:__sub(other) return Expression.new(self):add(other, -1.0) end
function Expression:__mul(other) return Expression.new(self):multiply(other) end
-- Only division by a number or a constant Expression is linear. Anything else
-- is rejected explicitly; the original recursed forever via `1.0/other`.
function Expression:__div(other)
  local divisor = tonumber(other)
  if not divisor and getmetatable(other) == Expression and other:is_constant() then
    divisor = other.constant or 0.0
  end
  if not divisor then
    error("attempt to divide by a non-constant expression", 2)
  end
  return Expression.new(self):multiply(1.0 / divisor)
end
--- Create a solver variable.
-- @param name display name.
-- @param type one of "external" (default), "slack", "error" or "dummy".
-- @param id   optional numeric id (the solver assigns one when needed).
function Variable.new(name, type, id)
  local kind = type or "external"
  assert(kind == "external" or kind == "slack"
         or kind == "error" or kind == "dummy", kind)
  local pivotable = kind == "slack" or kind == "error"
  return setmetatable({
    id = id,
    name = name,
    value = 0.0,
    type = kind,
    is_dummy = kind == "dummy",
    is_slack = kind == "slack",
    is_error = kind == "error",
    is_external = kind == "external",
    is_pivotable = pivotable,
    -- every non-external variable is restricted to non-negative values
    is_restricted = kind ~= "external",
  }, Variable)
end
--- Render as `Var(<type>): <name>(<value>)`.
function Variable:__tostring()
  local fmt = "Var(%s): %s(%g)"
  return string.format(fmt, self.type, self.name, self.value)
end
--- Build a fresh Expression and delegate to Expression:add with the same
-- arguments (`other` may be nil, a number, a Variable or an Expression).
function Expression.new(other, multiplier, constant)
  local expr = setmetatable({}, Expression)
  expr:add(other, multiplier, constant)
  return expr
end
--- Render as `Exp: <constant> + c1*name1 - c2*name2 ...`.
function Expression:__tostring()
  local parts = { ("%g"):format(self.constant or 0.0) }
  for var, coef in self:iter_vars() do
    local fmt, magnitude = " - %g*%s", -coef
    if coef > 0.0 then
      fmt, magnitude = " + %g*%s", coef
    end
    parts[#parts + 1] = fmt:format(magnitude, var.name)
  end
  return "Exp: " .. table.concat(parts)
end
--- In place: self += multiplier * other + constant. Returns self.
-- @param other      nil, a number, a Variable or an Expression.
-- @param multiplier optional scale for `other` (default 1.0).
-- @param constant   optional extra constant term.
-- Coefficients that become near zero are removed from the table.
function Expression:add(other, multiplier, constant)
  if other == nil then
    -- Honour a bare constant instead of silently dropping it, but leave the
    -- `constant` field absent when none is given, as before (so that
    -- `Expression.new()` still has no constant field).
    if constant then
      self.constant = (self.constant or 0.0) + constant
    end
    return self
  end
  multiplier = multiplier or 1.0
  self.constant = (self.constant or 0.0) + (constant or 0.0)
  local number = tonumber(other)
  if number then
    self.constant = self.constant + number * multiplier
    return self
  end
  local mt = getmetatable(other)
  if mt == Variable then
    local coef = (self[other] or 0.0) + multiplier
    if near_zero(coef) then coef = nil end
    self[other] = coef
  elseif mt == Expression then
    -- Walks every key of `other`, including its "constant" field.
    for key, coef in pairs(other) do
      local sum = (self[key] or 0.0) + multiplier * coef
      if near_zero(sum) then sum = nil end
      self[key] = sum
    end
    -- A near-zero constant was removed above; snap it back to exactly 0.0.
    self.constant = self.constant or 0.0
  else
    error("constant/variable/expression expected")
  end
  return self
end
--- In place: self *= other. Returns self.
-- `other` may be a number, a Variable or an Expression. At least one factor
-- must be constant, otherwise the product would be non-linear and an error
-- is raised.
function Expression:multiply(other)
  local factor = tonumber(other)
  if factor then
    if factor == 0 then
      -- Scaling by zero removes every variable term rather than leaving
      -- zero-coefficient entries that would make is_constant() false.
      for key in pairs(self) do
        if key ~= "constant" then self[key] = nil end
      end
      if self.constant then self.constant = 0.0 end
    else
      for key, coef in pairs(self) do
        self[key] = coef * factor
      end
    end
    return self
  end
  local mt = getmetatable(other)
  if mt == Variable then
    return self:multiply(Expression.new(other))
  elseif mt == Expression then
    if other:is_constant() then
      -- `other` may have no constant field at all (e.g. Expression.new()).
      return self:multiply(other.constant or 0.0)
    elseif self:is_constant() then
      -- Only valid when `self` has no variable terms. The original checked
      -- merely `self.constant`, silently turning (x + k) * y into k * (x + y).
      local k = self.constant or 0.0
      self.constant = 0.0
      return self:add(other):multiply(k)
    end
    error("attempt to multiply two non-constant expression")
  else
    error("number/variable/constant expression expected")
  end
end
--- Return the first (in traversal order) slack or error variable that has a
-- term in this expression, or nil if there is none.
function Expression:choose_pivotable()
  for var in self:iter_vars() do
    if var.is_pivotable then
      return var
    end
  end
  return nil
end
--- True when the expression has no variable terms (only a constant).
function Expression:is_constant()
  local step, expr = self:iter_vars()
  return step(expr, nil) == nil
end
--- Rewrite this expression in place so that it defines `new`.
-- If the row reads  old = a*new + c + sum(a_i*v_i)  (or 0 = ... when `old`
-- is nil), it becomes  new = (1/a)*old - c/a - sum((a_i/a)*v_i).
-- `new` must have a non-negligible coefficient. Returns `new`.
function Expression:solve_for(new, old)
  local coef = assert(self[new])
  assert(new ~= old and not near_zero(coef))
  self[new] = nil
  local inverse = 1.0 / coef
  self:multiply(-inverse)
  if old then
    self[old] = inverse
  end
  return new
end
--- Replace every occurrence of `var` with `expr` (scaled by var's coefficient).
-- Does nothing when `var` has no term here.
function Expression:substitute_out(var, expr)
  assert(var ~= "constant")
  local coef = self[var]
  if coef then
    self[var] = nil
    self:add(expr, coef)
  end
end
--- Generic-for iterator over (variable, coefficient) pairs, skipping the
-- "constant" field. Usage: `for var, coef in expr:iter_vars() do ... end`.
function Expression:iter_vars()
  local function step(expr, key)
    local var, coef = next(expr, key)
    if var == "constant" then
      var, coef = next(expr, var)
    end
    return var, coef
  end
  return step, self
end
--- Build a constraint `expr1 op expr2`, with op one of '==', '<=' or '>='.
-- The relation is normalised to `expression == 0` or `expression >= 0`.
-- Strength defaults to Constraint.REQUIRED.
function Constraint.new(op, expr1, expr2, strength)
  assert(op == '==' or op == '<=' or op == '>=',
         "op must be '==', '>=' or '<='")
  local lhs, rhs = expr1, expr2
  if op == '<=' then
    -- a <= b  is stored as  b - a >= 0
    lhs, rhs = expr2, expr1
  end
  local cons = setmetatable({
    expression = Expression.new(lhs or 0.0):add(rhs, -1.0),
    is_inequality = op ~= '==',
  }, Constraint)
  return cons:set_strength(strength or Constraint.REQUIRED)
end
--- Shorthand for Constraint.new("==", expr1, expr2, strength).
function Constraint.equal(expr1, expr2, strength)
  return Constraint.new("==", expr1, expr2, strength)
end
--- Shorthand for Constraint.new("<=", expr1, expr2, strength).
function Constraint.less_equal(expr1, expr2, strength)
  return Constraint.new("<=", expr1, expr2, strength)
end
--- Shorthand for Constraint.new(">=", expr1, expr2, strength).
function Constraint.great_equal(expr1, expr2, strength)
  return Constraint.new(">=", expr1, expr2, strength)
end
--- Set the constraint's strength and refresh `is_required`. Returns self.
-- @param strength a number, a numeric string, or a strength name such as
--   "REQUIRED", "STRONG", "MEDIUM" or "WEAK" (matched case-insensitively).
--   Any other value keeps the current strength.
-- Raises an error if no numeric strength results.
function Constraint:set_strength(strength)
  local value
  if type(strength) == "number" then
    value = strength
  elseif type(strength) == "string" then
    -- Accept only numeric fields, so a name like "new" cannot resolve to a
    -- method of the Constraint class.
    value = Constraint[strength]
    if type(value) ~= "number" then value = Constraint[strength:upper()] end
    if type(value) ~= "number" then value = tonumber(strength) end
  end
  self.strength = value or self.strength
  if type(self.strength) ~= "number" then
    error(("invalid constraint strength: %s"):format(tostring(strength)), 2)
  end
  self.is_required = self.strength >= Constraint.REQUIRED
  return self
end
--- Copy this constraint (deep-copying its expression), optionally giving the
-- copy a different strength.
function Constraint:clone(strength)
  local copy = setmetatable({
    expression = Expression.new(self.expression),
    is_inequality = self.is_inequality,
  }, Constraint)
  copy.type = self.type
  return copy:set_strength(strength or self.strength)
end
end
-- SimplexSolver class. Its private helpers are locals inside the `do` block
-- below; the public methods follow the "-- interface" marker further down.
local SimplexSolver = meta "SimplexSolver" do
-- implements (private helpers)
--- Copy solved values into the variables tracked in `self.vars`: a basic
-- variable takes its row's constant, and a non-basic one is 0.
local function update_external_variables(self)
  local rows = self.rows
  for var in pairs(self.vars) do
    local value = 0.0
    local row = rows[var]
    if row and row.constant then
      value = row.constant
    end
    var.value = value
  end
end
--- Replace `var` by `expr` in every tableau row and in the objective. Each
-- row whose basic variable is restricted and whose constant goes negative is
-- queued on `self.infeasible_rows`.
local function substitute_out(self, var, expr)
  local infeasible = self.infeasible_rows
  for basic, row in pairs(self.rows) do
    row:substitute_out(var, expr)
    if basic.is_restricted and row.constant < 0.0 then
      infeasible[#infeasible + 1] = basic
    end
  end
  self.objective:substitute_out(var, expr)
end
--- Primal simplex loop: keep pivoting until no non-dummy variable in
-- `objective` has a negative coefficient.
-- @param objective expression to optimise; defaults to self.objective. When a
--   different expression is passed in, it is also updated after each pivot.
-- Raises "objective function is unbounded" when no leaving row exists.
local function optimize(self, objective)
  objective = objective or self.objective
  while true do
    local entry, exit
    -- Entering variable: the first (in pairs order) non-dummy variable whose
    -- objective coefficient is negative.
    for var, multiplier in objective:iter_vars() do
      if not var.is_dummy and multiplier < 0.0 then
        entry = var
        break
      end
    end
    if not entry then return end
    local r = 0.0
    local min_ratio = math.huge
    -- Leaving row (ratio test): among rows whose basic variable is pivotable
    -- (slack/error) and in which `entry` has a negative coefficient, take the
    -- smallest -constant/coefficient; near-ties go to the lower variable id.
    -- NOTE(review): `exit` is nil only while min_ratio is math.huge, and
    -- approx(r, math.huge) is false for finite r, so `exit.id` is not reached
    -- with exit == nil in practice.
    for var, row in pairs(self.rows) do
      local multiplier = row[entry]
      if multiplier and var.is_pivotable and multiplier < 0.0 then
        r = -row.constant / multiplier
        if r < min_ratio or (approx(r, min_ratio) and
          var.id < exit.id) then
          min_ratio, exit = r, var
        end
      end
    end
    assert(exit, "objective function is unbounded")
    -- do pivot: `entry` becomes basic in place of `exit`
    local row = self.rows[exit]
    self.rows[exit] = nil
    row:solve_for(entry, exit)
    substitute_out(self, entry, row)
    -- substitute_out only touches self.objective, so a caller-supplied
    -- objective is updated separately here.
    if objective ~= self.objective then
      objective:substitute_out(entry, row)
    end
    self.rows[entry] = row
  end
end
-- Name prefixes and Variable types for the internal kinds make_variable
-- understands; any other kind gets prefix "s" and keeps its own type.
local VAR_PREFIX = { eplus = "ep", eminus = "em", dummy = "d", artificial = "a" }
local VAR_TYPE = { artificial = "slack", eplus = "error", eminus = "error" }
--- Create an internal variable with a fresh id.
-- @param type nil (slack), "artificial", "eplus", "eminus" or "dummy".
local function make_variable(self, type)
  local id = self.last_varid
  self.last_varid = id + 1
  local prefix = VAR_PREFIX[type] or "s"
  local kind = "slack"
  if type then
    kind = VAR_TYPE[type] or type
  end
  return Variable.new(prefix .. id, kind, id)
end
--- Translate a constraint into a tableau row over the current non-basic
-- variables and add its marker/error variables.
-- Returns the row, the marker variable, and the second error variable (if any).
-- Error variables of non-required constraints are added to the objective,
-- weighted by the constraint's strength.
local function make_expression(self, cons)
  local source = cons.expression
  local expr = Expression.new(source.constant)
  for var, coef in source:iter_vars() do
    if not var.id then
      var.id = self.last_varid
      self.last_varid = var.id + 1
    end
    if not self.vars[var] then
      self.vars[var] = true
    end
    -- Basic variables are replaced by their defining rows.
    expr:add(self.rows[var] or var, coef)
  end
  local marker, other
  if cons.is_inequality then
    marker = make_variable(self) -- slack
    expr[marker] = -1.0
    if not cons.is_required then
      other = make_variable(self, "eminus")
      expr[other] = 1.0
      self.objective[other] = cons.strength
    end
  elseif cons.is_required then
    marker = make_variable(self, "dummy")
    expr[marker] = 1.0
  else
    marker = make_variable(self, "eplus")
    other = make_variable(self, "eminus")
    expr[marker] = -1.0
    expr[other] = 1.0
    self.objective[marker] = cons.strength
    self.objective[other] = cons.strength
  end
  -- Keep the row constant non-negative.
  if expr.constant < 0.0 then
    expr:multiply(-1.0)
  end
  return expr, marker, other
end
--- Pick the variable to make basic for a new row `expr`.
-- Returns the variable; or nil when an artificial variable is needed; or
-- nil plus an error message when a required constraint is unsatisfiable.
local function choose_subject(self, expr, var1, var2)
  -- An unrestricted (external) variable can always become basic.
  for var in expr:iter_vars() do
    if var.is_external then return var end
  end
  -- A restricted marker may be chosen only when its coefficient is negative.
  -- Solving 0 = c + a*m (c >= 0) gives m = -c/a, which is non-negative only
  -- for a < 0; otherwise the new row would start out infeasible.
  if var1 and var1.is_pivotable and (expr[var1] or 0.0) < 0.0 then
    return var1
  end
  if var2 and var2.is_pivotable and (expr[var2] or 0.0) < 0.0 then
    return var2
  end
  for var in expr:iter_vars() do
    if not var.is_dummy then return nil end -- no luck: use an artificial var
  end
  -- Only dummy variables remain, so the row is satisfiable only if 0 == constant.
  if not near_zero(expr.constant) then
    return nil, "unsatisfiable required constraint added"
  end
  return var1
end
--- Add row `expr` using an artificial variable `a` (a = expr). Minimises a
-- to zero, then removes `a` from the tableau.
-- Returns true if the row could be satisfied.
local function add_with_artificial_variable(self, expr)
  local a = make_variable(self, 'artificial')
  self.rows[a] = expr
  -- Optimise a private COPY of the row. The tableau row itself is rewritten
  -- by pivots (it may even become another variable's row), so it cannot also
  -- serve as the objective. optimize() keeps this copy in sync.
  local artificial = Expression.new(expr)
  optimize(self, artificial)
  local success = near_zero(artificial.constant)
  local row = self.rows[a]
  self.rows[a] = nil
  if row then
    -- `a` is still basic: pivot some other variable in for it if possible.
    if row:is_constant() then
      return success
    end
    local entering = row:choose_pivotable()
    if not entering then return false end
    row:solve_for(entering, a)
    -- The entering variable must be eliminated from all other rows.
    substitute_out(self, entering, row)
    self.rows[entering] = row
  end
  -- `a` is now non-basic; drop its column everywhere.
  for _, r in pairs(self.rows) do r[a] = nil end
  self.objective[a] = nil
  return success
end
--- Choose the basic variable whose row should leave so that `marker` can be
-- pivoted into the basis. Preference order:
-- 1. the restricted row with a negative marker coefficient and the smallest
--    -constant/coef;
-- 2. the restricted row with a positive coefficient and the smallest
--    constant/coef;
-- 3. any external row containing the marker.
-- Returns nil if no row contains `marker`.
local function get_marker_leaving_row(self, marker)
  local neg_best, neg_ratio = nil, math.huge
  local pos_best, pos_ratio = nil, math.huge
  local external
  for basic, row in pairs(self.rows) do
    local coef = row[marker]
    if coef then
      if basic.is_external then
        external = basic
      elseif coef < 0.0 then
        local ratio = -row.constant / coef
        if ratio < neg_ratio then neg_best, neg_ratio = basic, ratio end
      else
        local ratio = row.constant / coef
        if ratio < pos_ratio then pos_best, pos_ratio = basic, ratio end
      end
    end
  end
  return neg_best or pos_best or external
end
--- Shift the constant of an edit constraint by `delta`.
-- @param var1 the edit's plus error variable; @param var2 its minus one.
-- If var1 is basic, only its row changes (by -delta). Otherwise, if var2 is
-- basic, only its row changes (by +delta). Otherwise every row is shifted by
-- its var1 coefficient times delta. Rows that go negative are queued in
-- `self.infeasible_rows`.
local function delta_edit_constant(self, delta, var1, var2)
  local rows, infeasible = self.rows, self.infeasible_rows
  local plus_row = rows[var1]
  if plus_row then
    plus_row.constant = plus_row.constant - delta
    if plus_row.constant < 0.0 then
      infeasible[#infeasible + 1] = var1
    end
    return
  end
  local minus_row = rows[var2]
  if minus_row then
    minus_row.constant = minus_row.constant + delta
    if minus_row.constant < 0.0 then
      infeasible[#infeasible + 1] = var2
    end
    return
  end
  for basic, row in pairs(rows) do
    row.constant = row.constant + (row[var1] or 0.0) * delta
    if basic.is_restricted and row.constant < 0.0 then
      infeasible[#infeasible + 1] = basic
    end
  end
end
--- Dual simplex: restore feasibility of the rows queued in
-- `self.infeasible_rows` while keeping the objective optimal.
-- Raises "dual optimize failed" if a row cannot be repaired.
local function dual_optimize(self)
  while true do
    local count = #self.infeasible_rows
    if count == 0 then return end
    local exit = self.infeasible_rows[count]
    self.infeasible_rows[count] = nil
    local row = self.rows[exit]
    -- A variable may be queued more than once, or its row may already have
    -- been repaired by an earlier pivot. Only pivot rows that are still
    -- infeasible; the original pivoted feasible rows too and could trip the
    -- assertion below.
    if row and row.constant < 0.0 and not near_zero(row.constant) then
      local entry
      local min_ratio = math.huge
      for var, multiplier in row:iter_vars() do
        if not var.is_dummy and multiplier > 0.0 then
          local r = (self.objective[var] or 0.0) / multiplier
          if r < min_ratio then
            min_ratio, entry = r, var
          end
        end
      end
      assert(entry, "dual optimize failed")
      -- pivot
      self.rows[exit] = nil
      row:solve_for(entry, exit)
      substitute_out(self, entry, row)
      self.rows[entry] = row
    end
  end
end
-- interface
-- Membership queries: each returns the stored entry (truthy) or nil.
function SimplexSolver:has_variable(var)
  local vars = self.vars
  return vars[var]
end
function SimplexSolver:has_constraint(cons)
  local constraints = self.constraints
  return constraints[cons]
end
function SimplexSolver:has_edit_var(var)
  local edits = self.edits
  return edits[var]
end
--- Create an empty solver.
function SimplexSolver.new()
  return setmetatable({
    last_varid = 1,          -- next id handed out to variables
    vars = {},               -- set of external variables seen so far
    constraints = {},        -- constraint -> { marker = ..., other = ... }
    objective = Expression.new(),
    rows = {},               -- basic variable -> defining Expression
    infeasible_rows = {},    -- stack of basic vars to repair in dual_optimize
    edits = {},              -- edit variable -> edit info
    stays = {},
  }, SimplexSolver)
end
--- Multi-line debug dump of the objective, rows, edits, pending infeasible
-- rows and tracked variables.
function SimplexSolver:__tostring()
  local t = { " ----- SimplexSolver info -----\n" }
  t[#t+1] = (" objective %s\n"):format(tostring(self.objective))
  if next(self.rows) then
    t[#t+1] = " rows:\n"
    local idx = 1
    for k, v in pairs(self.rows) do
      t[#t+1] = (" %d. %s(%g):\t%s\n"):format(idx, k.name, k.value, tostring(v))
      idx = idx + 1
    end
  end
  if next(self.edits) then
    t[#t+1] = " edits:\n"
    local idx = 1
    for k, v in pairs(self.edits) do
      t[#t+1] = (" %d. %s(%s) { %s, %s, %g }\n"):format(
        idx, k.name, k.value, tostring(v.plus), tostring(v.minus),
        v.prev_constant)
      idx = idx + 1
    end
  end
  if #self.infeasible_rows ~= 0 then
    t[#t+1] = " infeasible_rows:"
    for _, var in ipairs(self.infeasible_rows) do
      t[#t+1] = (" %s"):format(var.name)
    end
    t[#t+1] = "\n"
  end
  -- `vars` is keyed by Variable objects, so `#self.vars` is always 0 and the
  -- section never printed; test emptiness with next() instead.
  if next(self.vars) then
    t[#t+1] = " vars:"
    for var in pairs(self.vars) do
      t[#t+1] = (" %s"):format(var.name)
    end
    t[#t+1] = "\n"
  end
  t[#t+1] = " ----------------------------\n"
  return table.concat(t)
end
--- Add a constraint to the solver and re-solve.
-- @param cons     a Constraint.
-- @param strength optional; if given, a copy of `cons` with this strength is
--   added instead (the caller's constraint is left unchanged).
-- @return the constraint actually added, or nil plus an error message.
function SimplexSolver:add_constraint(cons, strength)
  -- Clone the given constraint. The original called `Constraint:clone`,
  -- which cloned the class table and produced an empty constraint.
  if strength then cons = cons:clone(strength) end
  if self.constraints[cons] then return cons end
  local expr, var1, var2 = make_expression(self, cons)
  local subject, err = choose_subject(self, expr, var1, var2)
  if subject then
    expr:solve_for(subject)
    substitute_out(self, subject, expr)
    self.rows[subject] = expr
  elseif err then
    return nil, err
  elseif not add_with_artificial_variable(self, expr) then
    return nil, "constraint added may unbounded"
  end
  self.constraints[cons] = {
    marker = var1,
    other = var2,
  }
  optimize(self)
  update_external_variables(self)
  return cons
end
--- Remove a previously added constraint and re-solve.
-- Returns `cons`, or nothing if it was not in the solver.
function SimplexSolver:remove_constraint(cons)
  local info = self.constraints[cons]
  if not info then return end
  self.constraints[cons] = nil
  local rows, objective = self.rows, self.objective
  local marker = info.marker
  -- Take the error variables' weighted terms back out of the objective.
  local function drop_error_effect(var)
    if var and var.is_error then
      objective:add(rows[var] or var, -cons.strength)
    end
  end
  drop_error_effect(marker)
  drop_error_effect(info.other)
  if objective:is_constant() then
    objective.constant = 0.0
  end
  if rows[marker] then
    -- The marker is basic: its row can simply be dropped.
    rows[marker] = nil
  else
    -- Pivot the marker into the basis first, then discard that row.
    local leaving = assert(get_marker_leaving_row(self, marker),
                           "failed to find leaving row")
    local row = rows[leaving]
    rows[leaving] = nil
    row:solve_for(marker, leaving)
    substitute_out(self, marker, row)
  end
  optimize(self)
  update_external_variables(self)
  return cons
end
--- Make `var` editable through suggest_value.
-- @param value    initial suggested value (defaults to var.value, then 0).
-- @param strength must be below REQUIRED (defaults to STRONG).
-- Returns self, or nothing if `var` is already being edited.
function SimplexSolver:add_edit_var(var, value, strength)
  if self.edits[var] then return end
  local edit_strength = strength or Constraint.STRONG
  assert(edit_strength < Constraint.REQUIRED, "attempt to edit a required var")
  local initial = value or var.value or 0.0
  local cons = Constraint.new("==", var, initial, edit_strength)
  assert(self:add_constraint(cons))
  local tags = self.constraints[cons]
  self.edits[var] = {
    constraint = cons,
    plus = tags.marker,
    minus = tags.other,
    prev_constant = initial,
  }
  return self
end
--- Stop editing `var` (removes its edit constraint). No-op if not edited.
function SimplexSolver:remove_edit_var(var)
  local edit = self.edits[var]
  if not edit then return end
  self:remove_constraint(edit.constraint)
  self.edits[var] = nil
end
--- Suggest a new value for edit variable `var` and re-solve with the dual
-- simplex. If `var` is not being edited yet, it is added as an edit variable.
function SimplexSolver:suggest_value(var, value)
  local edit = self.edits[var]
  if not edit then
    return self:add_edit_var(var, value)
  end
  local delta = value - edit.prev_constant
  edit.prev_constant = value
  delta_edit_constant(self, delta, edit.plus, edit.minus)
  dual_optimize(self)
  update_external_variables(self)
end
--- Change the strength of a non-required constraint that is already in the
-- solver. If the constraint is absent, it is added with strength `s`.
-- Changing to or from REQUIRED is rejected (error vars would be missing).
-- Returns self (or the add_constraint result when adding).
function SimplexSolver:chaneg_strength(cons, s)
  local info = self.constraints[cons]
  if not info then return self:add_constraint(cons, s) end
  assert(s ~= nil, "strength expected")
  assert(not cons.is_required, "attempt to change required strength")
  -- The original read an undefined global `strength` instead of `s`.
  local old = cons.strength
  cons:set_strength(s)
  if cons.is_required then
    cons:set_strength(old)
    error("attempt to change required strength")
  end
  local diff = cons.strength - old
  if near_zero(diff) then return self end
  -- Re-weight only the error variables. For a non-required inequality the
  -- marker is a slack variable and the error variable is `other`.
  local objective, rows = self.objective, self.rows
  local function reweight(var)
    if var and var.is_error then
      objective:add(rows[var] or var, diff)
    end
  end
  reweight(info.marker)
  reweight(info.other)
  optimize(self)
  update_external_variables(self)
  return self
end
-- Correctly spelled alias; the misspelled name is kept for compatibility.
SimplexSolver.change_strength = SimplexSolver.chaneg_strength
--- Repair any queued infeasible rows and refresh the external variables.
function SimplexSolver:resolve()
  dual_optimize(self)
  -- The original called `set_external_variables()` and
  -- `reset_stay_constant(self)`. Neither is defined in this module, so
  -- resolve() always raised an error.
  -- NOTE(review): no stay API exists here (`self.stays` is never filled), so
  -- there are no stay constants to reset.
  update_external_variables(self)
  self.infeasible_rows = {}
end
end
-- Module exports. `C` previously pointed at Constraint, but `Constraint` was
-- mistakenly bound to Variable; both now export Constraint. Expression is
-- exported as well so callers can build expressions explicitly.
return {
  V = Variable, C = Constraint, E = Expression,
  Variable = Variable,
  Expression = Expression,
  Constraint = Constraint,
  SimplexSolver = SimplexSolver,
}
-- Demo script (a separate file in the original gist). It lays out three
-- positions: xm is the midpoint of xl and xr, xr stays at least 10 right of
-- xl, and everything is kept within [0, 100]. It then drags xm via an edit
-- variable.
local C = require "cassowary"
local Variable, Constraint, SimplexSolver = C.V, C.C, C.SimplexSolver

local solver = SimplexSolver.new()
local xl, xm, xr = Variable.new "xl", Variable.new "xm", Variable.new "xr"

local c1 = assert(solver:add_constraint(Constraint.new("==", 2*xm, xl + xr)))
local c2 = assert(solver:add_constraint(Constraint.new("<=", xl + 10, xr)))
local c3 = assert(solver:add_constraint(Constraint.new("<=", xr, 100)))
local c4 = assert(solver:add_constraint(Constraint.new("<=", 0, xl)))

-- Exercise removal followed by re-adding.
local all = { c1, c2, c3, c4 }
for i = 1, #all do solver:remove_constraint(all[i]) end
for i = 1, #all do solver:add_constraint(all[i]) end

local function dump(label)
  print(("+"):rep(78), label, solver)
end

dump "after initialize"
solver:add_edit_var(xm)
dump "after add edit var"
solver:suggest_value(xm, 0)
dump "after suggest_value to 0"
solver:suggest_value(xm, 70)
dump "after suggest_value to 70"
solver:remove_edit_var(xm)
dump "after delete edit var"
print(xl, xm, xr)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment