Skip to content

Instantly share code, notes, and snippets.

@baggepinnen
Created November 18, 2021 15:10
Show Gist options
  • Save baggepinnen/91c0f688fefa204e0f012d78d4d7d878 to your computer and use it in GitHub Desktop.
Save baggepinnen/91c0f688fefa204e0f012d78d4d7d878 to your computer and use it in GitHub Desktop.
using Statistics, LinearAlgebra
using Test
using StaticArrays
using ModelingToolkit, Symbolics
"""
    rk4(f, l, Ts)

Discretize continuous-time dynamics `f` and a running-loss function `l` with the classical
fourth-order Runge-Kutta method and sample time `Ts`.

Both `f` and `l` take the arguments `(x, u, t)`: `f` returns the state derivative and `l`
the instantaneous loss. The input `u` is held constant over the whole step.

Returns a function of the form `(xₖ, uₖ, t) -> (xₖ₊₁, loss)`, where `loss` is the RK4
approximation of the integral of `l` over `[t, t + Ts]`.
"""
function rk4(f::F, l::LT, Ts) where {F, LT}
    # Runge-Kutta 4 method
    return function (x, u, t)
        half = Ts / 2
        # Each intermediate state is built once and shared by the dynamics and the loss.
        f1, L1 = f(x, u, t), l(x, u, t)
        xa = x + half * f1
        f2, L2 = f(xa, u, t + half), l(xa, u, t + half)
        xb = x + half * f2
        f3, L3 = f(xb, u, t + half), l(xb, u, t + half)
        xc = x + Ts * f3
        f4, L4 = f(xc, u, t + Ts), l(xc, u, t + Ts)
        xnext = x + Ts / 6 * (f1 + 2 * f2 + 2 * f3 + f4)
        L = Ts / 6 * (L1 + 2 * L2 + 2 * L3 + L4)
        return xnext, L
    end
end
"""
    cartpole(x, u, t; mc = 1.0, mp = 0.2, l = 0.5, g = 9.81)

Continuous-time cart-pole dynamics in manipulator-equation form
`H(q) q̈ + C(q, q̇) q̇ + G(q) = B u`, returning the state derivative `[q̇; q̈]`.

# Arguments
- `x`: state `[q; q̇]` with `q = [cart position, pendulum angle]`.
- `u`: input; only `u[1]` is used, and it enters the cart (first) equation through `B`.
- `t`: time; ignored (the dynamics do not depend on time), kept for the `(x, u, t)` convention.

# Keywords
- `mc`, `mp`: cart mass and pendulum mass.
- `l`: pendulum length.
- `g`: gravitational acceleration.
"""
function cartpole(x, u, _; mc = 1.0, mp = 0.2, l = 0.5, g = 9.81)
    q = x[SA[1, 2]]
    qd = x[SA[3, 4]]
    s = sin(q[2])
    c = cos(q[2])
    H = @SMatrix [mc+mp mp*l*c; mp*l*c mp*l^2]
    # Parenthesized so the leading minus cannot be read as a subtraction inside the literal.
    C = @SMatrix [0 (-mp * qd[2] * l * s); 0 0]
    G = @SVector [0, mp * g * l * s]
    B = @SVector [1, 0]
    # Solve H q̈ = B u - C q̇ - G for the generalized accelerations.
    qdd = H \ (B * u[1] - C * qd - G)
    return [qd; qdd]
end
"""
    loss(x, u, t)

Running cost `xᵀ Q1 x + uᵀ Q2 u` using the global weight matrices `Q1` and `Q2`.
The time argument `t` is not used.
"""
function loss(x, u, t)
    # A control-rate penalty Δuᵀ Q3 Δu (Δu = u - uprev) is not implemented: it would need
    # the previous input, which this signature does not receive.
    state_term = (x' * Q1) * x
    input_term = (u' * Q2) * u
    return state_term + input_term
end
"""
    final_cost(x)

Terminal cost `xᵀ Q1 x` for the state `x`, using the global state-cost matrix `Q1`.
"""
final_cost(x) = (x' * Q1) * x # TODO: replace by Riccati solution
nu = 1 # number of controls
nx = 4 # number of states
Ts = 0.02 # sample time
# Time horizon: kept very small because generating the symbolic functions is slow;
# realistic N are in the hundreds.
N = 2
x0 = zeros(nx) # Initial state
x0[1] = 3 # cart pos
x0[2] = pi * 0.5 # pendulum angle
xr = zeros(nx) # reference state (not referenced by the costs below, which penalize x itself)
Q1 = diagm(ones(nx)) # state cost matrix (identity, sized to match nx)
Q2 = Ts * diagm(ones(nu)) # control cost matrix
Q3 = nothing # control-rate cost matrix; currently unused
# Control limits
umin = -10 * ones(nu)
umax = 10 * ones(nu)
# State limits (be careful with these; they may make the problem infeasible)
xmin = -50 * ones(nx)
xmax = 50 * ones(nx)
discrete = rk4(cartpole, loss, Ts) # discretize the continuous dynamics and the loss integral
# xp, L = discrete(x, u, t)
## Build symbolic representation of optimal control problem.
# The states and controls at every time point are decision variables; consecutive states
# are linked by equality constraints g(w) = 0.
w = [] # decision variables
w0 = [] # initial guess
lbw = [] # lower bound on w
ubw = [] # upper bound on w
g = [] # equality constraints
L = 0 # total cost
@variables x[1:nx](1) # initial value variable
x = collect(x) # Symbolic arrays are too buggy
append!(w, x)
append!(w0, x0)
# Initial state is fixed: its bounds are the variables themselves, so the bound
# functions return whatever value is supplied for them.
append!(lbw, x)
append!(ubw, x)
for n = 1:N # for whole time horizon N
    global x, L
    @variables u[1:nu](n)
    u = collect(u) # Symbolic arrays are too buggy
    append!(w, u)
    append!(w0, zeros(nu)) # one guess per control; TODO: add u0
    append!(lbw, umin)
    append!(ubw, umax)
    # Integrate over one sample interval starting at physical time (n - 1) * Ts.
    xp, l = discrete(x, u, (n - 1) * Ts)
    L += l
    @variables x[1:nx](n+1) # x in next time point
    x = collect(x) # Symbolic arrays are too buggy
    append!(w, x)
    append!(w0, zeros(nx)) # TODO: add warmstart
    append!(lbw, xmin)
    append!(ubw, xmax)
    append!(g, xp .- x) # propagated x is x in next time point
end
# Terminal cost is applied once, to the last state of the horizon.
L += final_cost(x)
## Derivatives and code generation for the problem (cost L, constraints g, variables w)
# J = Symbolics.jacobian(xp, x)
# Sparse Jacobian of the equality constraints g with respect to the decision variables w.
@time A = Symbolics.sparsejacobian(g, w); # This takes forever to print in full form
# Gradient and sparse Hessian of the total cost L with respect to w.
@time dw = Symbolics.gradient(L, w)
@time H = Symbolics.sparsehessian(L, w)
# Expression form (`Val(true)`) of the Jacobian function. Very slow (timing below), and
# `jacfun` is reassigned at the end of this section, so this result is discarded.
@time jacfun = build_function(A, w, expression = Val(true)); # takes forever
# 222.257406 seconds (963.52 M allocations: 40.937 GiB, 15.66% gc time, 0.41% compilation time)
# Compiled (`Val(false)`) functions of w: cost Hessian and the lower/upper variable bounds.
@time hessfun = build_function(H, w, expression = Val(false))
lbfun = build_function(lbw, w, expression = Val(false))
ubfun = build_function(ubw, w, expression = Val(false))
# For array outputs build_function returns several functions; index 1 is presumably the
# out-of-place variant (TODO confirm against the Symbolics docs).
res = lbfun[1](w0)
# The initial-state lower bounds are the initial-state variables themselves, so evaluating
# them at w0 must reproduce the first nx entries of the initial guess.
@test res[1:nx] == w0[1:nx]
# Compiled function returning only the stored nonzero entries of the sparse Jacobian.
@time jacfun = build_function(A.nzval, w, expression = Val(false));
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment