Skip to content

Instantly share code, notes, and snippets.

@jgillis
Last active June 4, 2024 08:05
Show Gist options
  • Save jgillis/d70a9f743bd5b9e94bc091683725ca50 to your computer and use it in GitHub Desktop.
Save jgillis/d70a9f743bd5b9e94bc091683725ca50 to your computer and use it in GitHub Desktop.
from casadi import *
# DaeBuilder walkthrough: a small 'rocket' model with two parameters (a, b),
# one control (u) and three states (h, v, m).
# Original example by Joel Andersson, UW Madison 2017.

# Begin from an empty DaeBuilder instance
dae = DaeBuilder('rocket')

# Declare the symbols: parameters first, then the control, then the states
# (each group is registered in the order listed)
a, b = (dae.add_p(name) for name in ('a', 'b'))
u = dae.add_u('u')
h, v, m = (dae.add_x(name) for name in ('h', 'v', 'm'))
def _codegen_and_load(fun, name, source):
    """Generate C code for *fun* plus its first forward derivative and load it back.

    A CodeGenerator called *name* receives *fun* and ``fun.forward(1)``; after
    generation, *source* is loaded through CasADi's 'shell' importer and the
    external Function *name* is returned.
    """
    gen = CodeGenerator(name)
    gen.add(fun)
    # Emit the forward derivative too; presumably so the loaded external can be
    # differentiated in forward mode later on.
    gen.add(fun.forward(1))
    gen.generate()  # assumes this writes '<name>.c', i.e. *source* — TODO confirm
    return external(name, Importer(source, 'shell'))


# f(h) = h**2 and g(h) = cos(h), each replaced by its compiled external version
f = _codegen_and_load(Function('f', [h], [h**2]), 'f', 'f.c')
g = _codegen_and_load(Function('g', [h], [cos(h)]), 'g', 'g.c')
# ODE right-hand sides, one per state (assigned in the order h, v, m).
# The 'm' equation goes through the external functions f and g.
ode_rhs = {
    'h': v,
    'v': (u - a*v**2)/m - 9.81,
    'm': -b*u**2 + g(f(h) + sqrt(v)),
}
for state, rhs in ode_rhs.items():
    dae.set_ode(state, rhs)

# Initial conditions
for state, start in (('h', 0), ('v', 0), ('m', 1)):
    dae.set_start(state, start)

# Meta information: units of the states
for state, unit in (('h', 'm'), ('v', 'm/s'), ('m', 'kg')):
    dae.set_unit(state, unit)

# ODE right-hand side as a Function of (u, x, p)
ode = dae.create('ode', ['u', 'x', 'p'], ['ode'])
# Reference Jacobian d(ode)/dx of the model before lifting. It gets its own
# Function name instead of reusing 'ode', which belongs to the function above.
J = dae.create('jac_ode_x', ['u', 'x', 'p'], ['jac_ode_x'])

# Numerical evaluation point; presumably x stacks (h, v, m) and p stacks (a, b)
# in declaration order — TODO confirm against DaeBuilder ordering
kwargs = dict(u=0.3, x=vertcat(0.1, 0.2, 0.3), p=vertcat(1, 1))
print(ode(**kwargs))
# Print the DAE, lift it, and print it again.
# NOTE(review): lift(False, True) presumably means lift_shared=False,
# lift_calls=True (pull function calls out into dependent variables w) — confirm.
dae.disp(True)
dae.lift(False, True)
dae.disp(True)

# After lifting, the model is implicit in the lifted variables w:
#   w    = wdef(u, x, p, w)
#   xdot = ode(u, x, p, w)
# f_w evaluates w explicitly; evaluate it once and reuse the value below.
f_w = dae.dependent_fun('f_w', ['u', 'x', 'p'], ['w'])
w_val = f_w(**kwargs)['w']
print(w_val)

wdef = dae.create('wdef', ['u', 'x', 'p', 'w'], ['wdef'])
print(wdef(**kwargs, w=w_val))  # expected to reproduce w_val (fixed point)

ode = dae.create('ode', ['u', 'x', 'p', 'w'], ['ode'])
print(ode(**kwargs, w=w_val))  # expected to match the unlifted ode output

# Partial Jacobians of the lifted model.
# NOTE(review): the trailing True, True are presumably create()'s sx and
# lifted_calls flags — confirm against the DaeBuilder.create docs.
f = dae.create('f', ['u', 'x', 'p', 'w'],
               ['jac_ode_x', 'jac_ode_w', 'jac_wdef_w', 'jac_wdef_x'], True, True)
# Debug aids: f.disp(True); print(f.find_functions())
res = f(**kwargs, w=w_val)

# Total Jacobian d(xdot)/dx via implicit differentiation of w = wdef(w, x):
#   dw/dx                 = d_wdef/dw * dw/dx + d_wdef/dx
#   (I - d_wdef/dw) dw/dx = d_wdef/dx
#   d(xdot)/dx            = d_ode/dx + d_ode/dw * dw/dx
n_w = res["jac_wdef_w"].shape[0]
dw_dx = solve(DM.eye(n_w) - res["jac_wdef_w"], res["jac_wdef_x"])
dxdot_dx = res["jac_ode_x"] + res["jac_ode_w"] @ dw_dx
print(dxdot_dx)

# Reference Jacobian from the unlifted model, and the largest absolute
# discrepancy between the two (should be close to zero)
J_ref = J(**kwargs)
print(J_ref)
print(norm_inf(dxdot_dx - J_ref["jac_ode_x"]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment