# Gist by @jgillis (jgillis/cc8b6ab7f1d174955dedd8ea7ab94a98), created August 3, 2021.
import casadi
from casadi import *
from casadi.casadi import OPTI_INEQUALITY, Opti_bounded
import matplotlib.pyplot as plt
# Cart-pole swing-up posed as an optimisation problem with CasADi's Opti stack.
# Model and problem data follow:
# https://epubs.siam.org/doi/pdf/10.1137/16M1062569

# --- model parameters (read by the dynamics function f) ---
m1 = 1
m2 = 0.3
g = 9.81
l = 0.5  # pendulum length

# --- limits, target and time horizon ---
dmax = 2.0  # max horiz distance
umax = 20.0  # max force
T = 2  # total time
d = 1  # horizontal distance
t0 = 0
tf = 2

plt.figure(1)
opti = casadi.Opti()

# --- time discretisation ---
N = 25  # number of collocation points
h = (tf - t0) / (N - 1)  # step size between consecutive points

# --- decision variables ---
x = opti.variable(4, N)  # states: 4 rows, one column per collocation point
y1 = x[0, :]  # cart pos
y2 = x[1, :]  # pendulum angle (enters the dynamics through sin/cos)
ydot1 = x[2, :]  # rate of y1
ydot2 = x[3, :]  # rate of y2
u = opti.variable(1, N - 1)  # control: one value per interval between points
def f(x, u):
    """Return the time derivative of the cart-pole state.

    ``x`` is split with ``vertsplit`` into four entries: cart position,
    pendulum angle and their two rates. ``u`` is the control term that
    enters both accelerations. The four derivatives are stacked with
    ``vertcat`` in the same order as the state.
    """
    # The cart position itself does not appear in the dynamics.
    _, theta, cart_rate, theta_rate = vertsplit(x)
    sin_t = sin(theta)
    cos_t = cos(theta)

    # Each acceleration is a numerator over an angle-dependent mass term.
    cart_num = (l*m2*sin_t*theta_rate**2) + u + (m2*g*cos_t*sin_t)
    cart_den = m1 + m2*(1 - cos_t**2)
    pend_num = (l*m2*cos_t*sin_t*theta_rate**2) + u*cos_t + ((m1 + m2)*g*sin_t)
    pend_den = l*m1 + l*m2*(1 - cos_t**2)

    return vertcat(
        cart_rate,
        theta_rate,
        cart_num / cart_den,
        -1*pend_num / pend_den,
    )
# Trapezoidal rule on every interval: the state increment must equal the step
# size times the average of the dynamics evaluated at both end points. The
# same control column u[:, k] is used at both ends of interval k.
for k in range(N - 1):
    x_left, x_right = x[:, k], x[:, k + 1]
    u_k = u[:, k]
    f_right = f(x_right, u_k)
    f_left = f(x_left, u_k)
    opti.subject_to(h / 2 * (f_right + f_left) == x_right - x_left)
# Path constraints: keep the cart position and the control inside their limits.
# opti.bounded is the documented spelling of the double-sided bound helper.
opti.subject_to(opti.bounded(-dmax, y1, dmax))
opti.subject_to(opti.bounded(-umax, u, umax))

# Boundary conditions: every state is zero at the first point and equals the
# goal state at the last point.
x_goal = vertcat(d, pi, 0, 0)
opti.subject_to(x[:, 0] == 0)
opti.subject_to(x[:, -1] == x_goal)

# Initial guess: the goal state scaled by a factor running linearly from 0 at
# the first column to 1 at the last. Built once, then printed and applied.
x_guess = repmat(x_goal, 1, N) * repmat(linspace(0, 1, N).T, 4, 1)
print(x_guess)
opti.set_initial(x, x_guess)

# Objective: sum of squared control values. minimize() only registers the
# objective, so its result is not kept; the solution comes from solve().
opti.minimize(sumsqr(u))

s_opts = {"ipopt.tol": 1e-6, "expand": True}
opti.solver('ipopt', s_opts)  # set numerical backend
sol = opti.solve()  # actual solve
# Figure 1: optimal control. u[:, k] acts on the interval that starts at
# t0 + k*h, so the N-1 samples are placed at the interval start times
# t0, t0 + h, ..., tf - h (spacing h), matching the collocation grid.
plt.figure(1)
t_u = linspace(t0, tf - h, N - 1)
print(sol)
plt.plot(t_u, sol.value(u))

# Figure 2: optimal state trajectory at the N collocation points; the
# transpose gives one row per point and one column per state.
plt.figure(2)
t1 = linspace(t0, tf, N)
x_opt = sol.value(x).T
print(x_opt)
plt.plot(t1, x_opt)
plt.show()
# End of gist (the GitHub sign-in prompt that followed was page chrome, not part of the script).