Skip to content

Instantly share code, notes, and snippets.

@Chachay
Created December 17, 2022 08:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Chachay/da0adf4fefa22316ebbbfa7bcd12b613 to your computer and use it in GitHub Desktop.
Save Chachay/da0adf4fefa22316ebbbfa7bcd12b613 to your computer and use it in GitHub Desktop.
An example of cart pole swing up by SQP
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# SQPによる非線形モデル予測制御\n",
"[SQPによる非線形モデル予測制御 \\- ヤカンヒコウ](https://blog.chachay.org/2022/12/sqp-mpc.html) の実装例"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import sympy as sy\n",
"from sympy.physics import mechanics\n",
"from scipy.signal import cont2discrete\n",
"\n",
"import cvxpy\n",
"from cvxpy import *\n",
"\n",
"from matplotlib.animation import ArtistAnimation\n",
"from IPython.display import HTML"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 台車型倒立振子のモデル\n",
"システム方程式のヤコビアン等の導出を簡単に行うためsympyを使った。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class cart_pole():\n",
"    \"\"\"Cart-pole plant: symbolic dynamics, linearisation and simulation.\n",
"\n",
"    The state is ``[x, theta, x_dot, theta_dot]`` and the input is the scalar\n",
"    ``u`` that enters the cart equation. The gravity term is restoring\n",
"    around ``theta = 0``, so that is the hanging-down equilibrium.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    l : pendulum length used in the equations of motion\n",
"    M : cart mass\n",
"    m : pendulum mass\n",
"    dT : control period\n",
"    obs : number of simulation sub-steps per control period\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, l = .8, M =1., m=.1, dT=0.02, obs = 1):\n",
"        # x = [x, theta, x_dot, theta_dot]\n",
"        self.x = np.zeros(4)\n",
"\n",
"        self.l = l\n",
"        self.M = M\n",
"        self.m = m\n",
"\n",
"        # Control period\n",
"        self.dT = dT\n",
"        # Simulation sub-steps per control period\n",
"        self.obs = obs\n",
"\n",
"        # The symbolic derivation (matrix inverse + simplify) is slow, so it\n",
"        # is done once and shared by the Jacobians and the RHS function.\n",
"        rhe = self.gen_rhe_sympy()\n",
"        self.A, self.B = self.gen_lmodel(rhe)\n",
"\n",
"        q = sy.symbols('q:{0}'.format(4))\n",
"        u = sy.symbols('u')\n",
"        self.calc_rhe = sy.lambdify([q, u], rhe)\n",
"\n",
"    def gen_rhe_sympy(self):\n",
"        \"\"\"Return the symbolic state derivative as a 4x1 sympy Matrix.\n",
"\n",
"        Obtained by solving ``I(q) * dx/dt = f(q, u)`` in the symbols\n",
"        ``q0..q3`` (state) and ``u`` (input).\n",
"        \"\"\"\n",
"        g = 9.8\n",
"        l = self.l\n",
"        M = self.M\n",
"        m = self.m\n",
"\n",
"        q = sy.symbols('q:{0}'.format(4))\n",
"        qd = q[2:4]\n",
"        u = sy.symbols('u')\n",
"\n",
"        I = sy.Matrix([[1, 0, 0, 0],\n",
"                       [0, 1, 0, 0],\n",
"                       [0, 0, M + m, l*m*sy.cos(q[1])],\n",
"                       [0, 0, l*m*sy.cos(q[1]), l**2*m]])\n",
"        f = sy.Matrix([\n",
"            qd[0],\n",
"            qd[1],\n",
"            l*m*sy.sin(q[1])*qd[1]**2 + u,\n",
"            -g*l*m*sy.sin(q[1])])\n",
"        return sy.simplify(I.inv()*f)\n",
"\n",
"    def gen_lmodel(self, mat=None):\n",
"        \"\"\"Return callables ``(A, B)`` evaluating the Jacobians of the dynamics.\n",
"\n",
"        ``A(q, u)`` is the derivative of the right-hand side with respect to\n",
"        the state and ``B(q, u)`` the derivative with respect to the input.\n",
"        ``mat`` may carry a precomputed result of ``gen_rhe_sympy()``; when it\n",
"        is omitted the expression is derived here.\n",
"        \"\"\"\n",
"        if mat is None:\n",
"            mat = self.gen_rhe_sympy()\n",
"        q = sy.symbols('q:{0}'.format(4))\n",
"        u = sy.symbols('u')\n",
"\n",
"        A = mat.jacobian(q)\n",
"        # Scalar input: a plain derivative is enough. With several inputs\n",
"        # this would have to be mat.jacobian(u).\n",
"        B = mat.diff(u)\n",
"\n",
"        return (sy.lambdify([q, u], A),\n",
"                sy.lambdify([q, u], B))\n",
"\n",
"    def gen_dmodel(self, x, u, dT):\n",
"        \"\"\"Linearise the dynamics at ``(x, u)`` and discretise them over ``dT``.\n",
"\n",
"        Returns ``(A_d, B_d, g_d)`` for the affine discrete model\n",
"        ``x_next = A_d @ x + B_d * u + g_d``.\n",
"        \"\"\"\n",
"        u = np.atleast_1d(u)\n",
"        f = np.array(self.calc_rhe(x, u[0])).ravel()\n",
"        A_c = np.array(self.A(x, u[0]))\n",
"        B_c = np.array(self.B(x, u[0])).ravel()\n",
"\n",
"        # Affine remainder of the expansion f(x, u) = A_c x + B_c u + g_c\n",
"        g_c = f - A_c@x - B_c*u\n",
"\n",
"        # Stack g_c as a second input column so it is discretised together\n",
"        # with B_c in a single cont2discrete call.\n",
"        B = np.vstack((B_c, g_c)).T\n",
"\n",
"        A_d, B_d, _, _, _ = cont2discrete((A_c, B, 0, 0), dT)\n",
"        g_d = B_d[:,1]\n",
"        B_d = B_d[:,0]\n",
"\n",
"        return A_d, B_d, g_d\n",
"\n",
"    def step(self, u):\n",
"        \"\"\"Advance ``self.x`` by one control period with the input ``u`` held constant.\n",
"\n",
"        The period is split into ``self.obs`` sub-steps and the model is\n",
"        re-linearised at the current state before each of them.\n",
"        \"\"\"\n",
"        dT = self.dT / self.obs\n",
"\n",
"        for _ in range(self.obs):\n",
"            A, B, g = self.gen_dmodel(self.x, u, dT)\n",
"            self.x = A@self.x + B * u + g"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 非線形モデル予測制御の実装"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class NMPC():\n",
"    \"\"\"Nonlinear MPC solved by sequential quadratic programming (SQP).\n",
"\n",
"    Each call to ``solve_NMPC`` builds one convex QP around the current\n",
"    trajectory guess and returns a damped update of that guess.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    dT : sampling period handed to ``model``\n",
"    time_horizon : number of prediction steps\n",
"    model : callable ``model(x, u, dT) -> (A_d, B_d, g_d)`` returning the\n",
"        discrete affine model linearised at ``(x, u)``\n",
"    x_ubounds, x_lbounds : upper / lower state bounds (length NX)\n",
"    u_ubounds, u_lbounds : upper / lower input bounds (length NU)\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, dT=0.02, time_horizon = 20, model=None,\n",
"                 x_ubounds=(), x_lbounds=(),\n",
"                 u_ubounds=(), u_lbounds=()):\n",
"        self.dT = dT\n",
"        self.time_horizon = time_horizon\n",
"        self.model = model\n",
"\n",
"        self.x_min = np.asarray(x_lbounds, dtype=np.float64)\n",
"        self.x_max = np.asarray(x_ubounds, dtype=np.float64)\n",
"        self.u_min = np.asarray(u_lbounds, dtype=np.float64)\n",
"        self.u_max = np.asarray(u_ubounds, dtype=np.float64)\n",
"\n",
"        # Problem dimensions are taken from the lower-bound vectors\n",
"        self.NX = self.x_min.shape[0]\n",
"        self.NU = self.u_min.shape[0]\n",
"\n",
"        # Damping parameter: fraction of the QP step applied per iteration\n",
"        self.alpha = 0.3\n",
"\n",
"    def set_weight(self, Q, R, q, r):\n",
"        \"\"\"Store the cost weights.\n",
"\n",
"        ``Q`` and ``R`` are the quadratic state / input weights used by\n",
"        ``solve_NMPC``. ``q`` and ``r`` are stored but not read by it.\n",
"        \"\"\"\n",
"        self.Q = Q\n",
"        self.R = R\n",
"        self.q = q\n",
"        self.r = r\n",
"\n",
"    def solve_NMPC(self, x_guess, u_guess, x_ref, verbose=False, warmstart=False):\n",
"        \"\"\"Solve one QP sub-problem linearised around ``(x_guess, u_guess)``.\n",
"\n",
"        Parameters\n",
"        ----------\n",
"        x_guess : array (NX, time_horizon+1), current state trajectory guess;\n",
"            its first column is imposed as the initial state\n",
"        u_guess : array (NU, time_horizon), current input trajectory guess\n",
"        x_ref : array (NX, time_horizon+1), reference state trajectory\n",
"        verbose, warmstart : forwarded to ``Problem.solve``\n",
"\n",
"        Returns\n",
"        -------\n",
"        ``(x_new, u_new, cost)`` where the trajectories are the damped blend\n",
"        ``alpha * solution + (1 - alpha) * guess``, or ``(None, None, None)``\n",
"        when the solver raised or did not reach an optimal status.\n",
"        \"\"\"\n",
"        T = self.time_horizon\n",
"\n",
"        # Decision variables of the QP sub-problem\n",
"        x = cvxpy.Variable((self.NX, T+1))\n",
"        u = cvxpy.Variable((self.NU, T))\n",
"\n",
"        objective = 0\n",
"\n",
"        ## Initial condition\n",
"        constraints = [x[:,0] == x_guess[:, 0]]\n",
"\n",
"        for k in range(T):\n",
"            objective += cvxpy.quad_form(x[:,k] - x_ref[:,k], self.Q) /2.\n",
"            objective += cvxpy.quad_form(u[:,k], self.R) /2.\n",
"            objective += (x_guess[:,k] - x_ref[:,k]).T@self.Q@x[:,k]\n",
"            objective += u_guess[:,k].T@self.R@u[:,k]\n",
"\n",
"            # Dynamics linearised at the k-th point of the guess\n",
"            Ad, Bd, gd = self.model(x_guess[:,k], u_guess[:,k], self.dT)\n",
"            constraints += [x[:,k+1] == Ad@x[:,k] + Bd*u[:,k] + gd]\n",
"\n",
"            constraints += [self.x_min <= x[:,k], x[:,k] <= self.x_max]\n",
"            constraints += [self.u_min <= u[:,k], u[:,k] <= self.u_max]\n",
"\n",
"        ## Terminal condition\n",
"        objective += cvxpy.quad_form(x[:,T] - x_ref[:,T], self.Q)/2.\n",
"        objective += (x_guess[:,T] - x_ref[:,T]).T@self.Q@x[:,T]\n",
"        constraints += [self.x_min <= x[:,T], x[:,T] <= self.x_max]\n",
"\n",
"        prob = cvxpy.Problem(cvxpy.Minimize(objective), constraints)\n",
"\n",
"        try:\n",
"            prob.solve(solver=cvxpy.ECOS, verbose=verbose, max_iters=100, warm_start=warmstart)\n",
"        except Exception as exc:\n",
"            # Stay best-effort for the caller, but make the reason visible\n",
"            # instead of hiding e.g. a missing solver behind a None result.\n",
"            warnings.warn('NMPC sub-problem could not be solved: {!r}'.format(exc),\n",
"                          RuntimeWarning)\n",
"            return None, None, None\n",
"\n",
"        if prob.status not in (cvxpy.OPTIMAL, cvxpy.OPTIMAL_INACCURATE):\n",
"            return None, None, None\n",
"\n",
"        # Damped update: move only a fraction alpha towards the QP solution\n",
"        ret_x = self.alpha * x.value + (1 - self.alpha) * x_guess\n",
"        ret_u = self.alpha * u.value + (1 - self.alpha) * u_guess\n",
"        return ret_x, ret_u, prob.value"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## パラメータ等の設定"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Prediction horizon (steps), control period and total simulated time\n",
"T = 20\n",
"dT = 0.05\n",
"sim_time = 20\n",
"\n",
"# Cost weights; q and r are passed to set_weight below\n",
"Q = np.diag([2., 200., 2., 20.])\n",
"R = np.diag([2.])\n",
"q = np.zeros(4)\n",
"r = np.zeros(1)\n",
"\n",
"# Symmetric box constraints on the state and on the input\n",
"xmax = np.array([3, 2*np.pi, 10, 10])\n",
"xmin = -xmax\n",
"\n",
"umax = np.array([10.])\n",
"umin = -umax\n",
"\n",
"model = cart_pole(dT=dT)\n",
"controller = NMPC(\n",
"    dT=dT,\n",
"    time_horizon=T,\n",
"    model=model.gen_dmodel,\n",
"    x_ubounds=xmax, x_lbounds=xmin,\n",
"    u_ubounds=umax, u_lbounds=umin,\n",
")\n",
"controller.set_weight(Q, R, q, r)\n",
"\n",
"# Goal state, repeated along the horizon as a (4, T+1) reference trajectory\n",
"x_ref = np.array([0, np.pi, 0, 0])\n",
"x_refs = np.repeat(x_ref.reshape(-1, 1), T+1, axis=1)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 制御 + シミュレーションのループ"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def MPC_sim(x0):\n",
"    \"\"\"Run the closed-loop NMPC simulation starting from state ``x0``.\n",
"\n",
"    Uses the notebook-level ``model``, ``controller``, ``x_refs``, ``T``,\n",
"    ``dT``, ``sim_time`` and ``umin``.\n",
"\n",
"    Returns\n",
"    -------\n",
"    o : list with the initial state followed by the state after every step\n",
"    ani : ArtistAnimation showing the cart (red dot) and pendulum (black line)\n",
"    \"\"\"\n",
"    # Trajectory guesses that are carried over between control steps\n",
"    u_guess = np.zeros((umin.shape[0], T), dtype=np.float64)\n",
"    x_guess = np.tile(x0, (T+1, 1)).T\n",
"    # Private float copy so the caller's array is never shared with the model\n",
"    model.x = np.array(x0, dtype=np.float64)\n",
"\n",
"    fig = plt.figure(figsize=(8, 8))\n",
"    plt.axis('equal')\n",
"    plt.xlim(-4, 4)\n",
"\n",
"    # observation logger\n",
"    o = [model.x]\n",
"    ims = []\n",
"\n",
"    # round() guards against sim_time/dT landing just below an integer\n",
"    for t in range(int(round(sim_time/dT))):\n",
"\n",
"        # Stop once the cart has run far away from the origin\n",
"        if abs(model.x[0]) > 15:\n",
"            break\n",
"\n",
"        # Warm start: shift the previous plan by exactly one step.\n",
"        # (The former [0:-2] <- [1:-1] shift dropped the last planned sample\n",
"        # and duplicated the one before it.)\n",
"        u_guess = np.roll(u_guess, -1, axis=1)\n",
"        u_guess[:, -1] = 0.\n",
"\n",
"        x_guess = np.roll(x_guess, -1, axis=1)\n",
"        x_guess[:, 0] = model.x\n",
"        x_guess[:, -1] = x_guess[:, -2]\n",
"\n",
"        # SQP iterations\n",
"        for i in range(30):\n",
"            x_tmp, u_tmp, _ = controller.solve_NMPC(x_guess, u_guess, x_refs)\n",
"\n",
"            if u_tmp is None:\n",
"                # The guess is unchanged after a failed solve, so another\n",
"                # attempt would pose the identical problem again.\n",
"                break\n",
"\n",
"            # Converged when the first input barely moved in this iteration\n",
"            converge = np.sum((u_guess[:,0]-u_tmp[:,0])**2) < 0.5\n",
"            x_guess = x_tmp\n",
"            u_guess = u_tmp\n",
"\n",
"            if converge:\n",
"                break\n",
"\n",
"        if t%3 == 0:\n",
"            # animation frame (every third step)\n",
"            ## Cart position\n",
"            im = plt.plot(model.x[0], 0, 'ro', animated=True)\n",
"            ## Pendulum\n",
"            x = [model.x[0], model.x[0]+np.sin(model.x[1])*model.l]\n",
"            y = [0, -np.cos(model.x[1])*model.l]\n",
"            im += plt.plot(x, y, 'k-', animated=True)\n",
"            ims.append(im)\n",
"\n",
"        # Sim step: apply only the first input of the plan\n",
"        model.step(u_guess[0,0])\n",
"        o.append(model.x)\n",
"\n",
"    ani = ArtistAnimation(fig, ims, interval=dT*1000*3, repeat_delay=1000)\n",
"    plt.close(fig)\n",
"    return o, ani"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Balancing quasi stable\n",
"LQRでも解けるほぼバランスしている状態からの制御"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# initial state: the goal x_ref with theta offset by -0.1, everything else on target\n",
"x0 = x_ref - np.array([0, 0.1, 0, 0])\n",
"obs, ani = MPC_sim(x0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Show the animation of the run above inline as an HTML5 video\n",
"HTML(ani.to_html5_video())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Swing up"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# initial state: all zeros (theta = 0), i.e. pi away from the goal angle in x_ref\n",
"x0 = np.array([0., 0., 0., 0.])\n",
"obs, ani = MPC_sim(x0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Show the swing-up animation inline as an HTML5 video\n",
"HTML(ani.to_html5_video())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "mpcc",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"vscode": {
"interpreter": {
"hash": "df5384cb4ec6a687391161e5d9b870394cde8c9f1f72f203fe1e596f0d090d6b"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment