Skip to content

Instantly share code, notes, and snippets.

@jgillis
Last active March 8, 2016 08:27
Show Gist options
  • Save jgillis/c9467993764b201982bb to your computer and use it in GitHub Desktop.
Save jgillis/c9467993764b201982bb to your computer and use it in GitHub Desktop.
Enforcing P >= 0 (P positive semidefinite) in CasADi with a callback-based penalty
% Build the checker for an n-by-n matrix input (here n = 4).
psdchecker = PSDchecker(4); % if you have 4 states
% create() yields the CasADi function object that is called on MX expressions.
psdchecker_fun = psdchecker.create();
Now you can call `psdchecker_fun` on an `MX` node, e.g.:
import casadi.*
% Symbolic 4-by-4 matrix; its size must match the n passed to PSDchecker,
% whose input shape is n-by-n.
P = MX.sym('P',4,4);
% The callback output is 0 when P passes the eigenvalue check and inf otherwise.
out = psdchecker_fun({P});
% NOTE: `objective` is assumed to be your existing MX objective expression.
objective = objective + out{1};
classdef mydergen < casadi.DerivativeGenerator2
  % MYDERGEN  Derivative generator whose sensitivities are identically zero.
  %   gen = mydergen(n, fwd) hard-codes one n-by-n input and one 1-by-1
  %   output for the differentiated function.
  %   fwd = true  -> forward mode: n-by-n seeds, 1-by-1 zero sensitivities.
  %   fwd = false -> reverse mode: 1-by-1 seeds, n-by-n zero sensitivities.
  properties
    fwd % true for forward mode, false for reverse (adjoint) mode
    n   % row/column count of the square input
  end
  methods
    function self = mydergen(n,fwd)
      self.n = n;
      self.fwd = fwd;
    end
    function out = paren(self,fcn,ndir)
      % Return an MXFunction named 'my_derivative' for ndir directions.
      %   Inputs : fcn's nominal inputs, fcn's nominal outputs, then one
      %            seed symbol per direction.
      %   Outputs: one all-zero sensitivity per direction.
      import casadi.*
      sym_in = fcn.symbolicInput();
      sym_out = fcn.symbolicOutput();
      seeds = {};
      sens = {};
      for k = 1:ndir
        if self.fwd
          % Forward seed has the input's shape; the sensitivity is scalar.
          seeds{end+1} = MX.sym('x',self.n,self.n);
          sens{end+1} = MX.zeros(1,1);
        else
          % Adjoint seed is scalar; the sensitivity has the input's shape.
          seeds{end+1} = MX.sym('x');
          sens{end+1} = MX.zeros(self.n,self.n);
        end
      end
      out = MXFunction('my_derivative', {sym_in{:}, sym_out{:}, seeds{:}}, sens);
    end
  end
end
classdef PSDchecker < casadi.Callback2
  % PSDCHECKER  CasADi callback mapping an n-by-n matrix P to a penalty.
  %   The output is 0 when the smallest eigenvalue of the symmetric part of
  %   P is >= -tol, and inf otherwise (non-finite P also yields inf).
  %   Derivatives are supplied through mydergen generators, which return
  %   all-zero sensitivities.
  %
  %   checker = PSDchecker(n)        % tol = 0: zero eigenvalues accepted
  %   checker = PSDchecker(n, tol)   % accept eigenvalues >= -tol
  properties
    n       % P is n-by-n
    fwd     % forward-mode derivative generator
    adj     % reverse-mode derivative generator
    tol = 0 % tolerance on the smallest eigenvalue (absorbs round-off)
  end
  methods
    function self = PSDchecker(n, tol)
      self.n = n;
      if nargin >= 2
        self.tol = tol;
      end
      self.fwd = mydergen(n,true);
      self.adj = mydergen(n,false);
    end
    function argout = paren(self,argin)
      % Evaluate the penalty numerically for the matrix in argin{1}.
      P = full(argin{1});
      if ~all(isfinite(P(:)))
        % eig() errors on NaN/Inf input; report such P as infeasible.
        argout{1} = inf;
        return
      end
      % x'*P*x only depends on the symmetric part of P. Symmetrizing also
      % guarantees real eigenvalues: on a slightly asymmetric P, eig may
      % return complex values, and min() orders complex data by magnitude,
      % which can hide a negative eigenvalue.
      P = (P + P.')/2;
      if min(eig(P)) < -self.tol
        argout{1} = inf;
      else
        argout{1} = 0;
      end
    end
    function out = inputShape(self,i)
      % The single input is an n-by-n matrix (the index i is not used).
      out = [self.n, self.n];
    end
    function out = options(self)
      % Attach the zero-sensitivity forward and reverse derivative functions.
      out = struct('custom_forward',self.fwd.create(), 'custom_reverse',self.adj.create());
    end
  end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment