Created
October 1, 2019 02:02
-
-
Save stephentu/8d801e2601945fac451cc659a31c1245 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import scipy\n", | |
"import scipy.linalg" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Seeded generator so that any counterexample found below can be reproduced;\n", | |
"# change the seed to explore a different sequence of random systems.\n", | |
"rng = np.random.RandomState(0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# n sets the size of A (n x n) and the number of rows of B;\n", | |
"# d sets the number of columns of B (n x d) and of rows of the gain K (d x n).\n", | |
"n, d = 3, 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def spectral_radius(M):\n", | |
"    \"\"\"Return the largest absolute value among the eigenvalues of M.\"\"\"\n", | |
"    eigenvalue_magnitudes = np.abs(np.linalg.eigvals(M))\n", | |
"    return max(eigenvalue_magnitudes)\n", | |
"\n", | |
"\n", | |
"def operator_norm(M):\n", | |
"    \"\"\"Return the 2-norm of M as computed by np.linalg.norm with ord=2.\"\"\"\n", | |
"    two_norm = np.linalg.norm(M, 2)\n", | |
"    return two_norm\n", | |
"\n", | |
"\n", | |
"def random_contractive_system(n, d, rng):\n", | |
"    \"\"\"Sample a random pair (A, B) where A has operator norm 0.99.\n", | |
"\n", | |
"    B is an n x d matrix of normal draws. A is an n x n matrix of normal\n", | |
"    draws divided by its operator norm and then multiplied by 0.99.\n", | |
"\n", | |
"    Returns:\n", | |
"        The tuple (A, B).\n", | |
"    \"\"\"\n", | |
"    # B is drawn before A so the order of random draws stays the same.\n", | |
"    input_matrix = rng.normal(size=(n, d))\n", | |
"    raw_dynamics = rng.normal(size=(n, n))\n", | |
"    unit_norm_dynamics = raw_dynamics / operator_norm(raw_dynamics)\n", | |
"    return 0.99 * unit_norm_dynamics, input_matrix\n", | |
" \n", | |
"\n", | |
"def dlyap(A, Q):\n", | |
"    \"\"\"Solve the discrete Lyapunov equation A^T P A - P + Q = 0 for P.\n", | |
"\n", | |
"    Args:\n", | |
"        A: square matrix.\n", | |
"        Q: square matrix with the same shape as A.\n", | |
"\n", | |
"    Returns:\n", | |
"        P such that np.allclose(A.T @ P @ A - P, -Q) holds.\n", | |
"\n", | |
"    Raises:\n", | |
"        AssertionError: if neither argument convention of\n", | |
"            scipy.linalg.solve_discrete_lyapunov yields a solution.\n", | |
"    \"\"\"\n", | |
"    # Newer versions of scipy solve A P A^T - P + Q = 0, while older ones\n", | |
"    # solve A^T P A - P + Q = 0. Rather than branching on the installed\n", | |
"    # version, try the newer convention first (pass A^T), fall back to the\n", | |
"    # older one (pass A), and keep whichever result satisfies the equation.\n", | |
"    for candidate in (A.T, A):\n", | |
"        P = scipy.linalg.solve_discrete_lyapunov(candidate, Q)\n", | |
"        if np.allclose(A.T.dot(P).dot(A) - P, -Q):\n", | |
"            return P\n", | |
"\n", | |
"    raise AssertionError(\n", | |
"        'solve_discrete_lyapunov did not return a solution of A^T P A - P + Q = 0'\n", | |
"    )\n", | |
" \n", | |
" \n", | |
"def policy_iteration(A, B, K, S, R):\n", | |
"    \"\"\"Run one policy-iteration step and return the updated gain.\n", | |
"\n", | |
"    Args:\n", | |
"        A: n x n matrix.\n", | |
"        B: n x d matrix.\n", | |
"        K: d x n current gain; the closed-loop matrix is A + B K.\n", | |
"        S: n x n cost block paired with A in the Q matrix.\n", | |
"        R: d x d cost block paired with B in the Q matrix.\n", | |
"\n", | |
"    Returns:\n", | |
"        K_next = -Q_{22}^{-1} Q_{12}^T, a d x n matrix.\n", | |
"\n", | |
"    Raises:\n", | |
"        AssertionError: if the spectral radius of A + B K is not below 1.\n", | |
"    \"\"\"\n", | |
"    n = B.shape[0]\n", | |
"    A_clp = A + B @ K\n", | |
"\n", | |
"    # The Lyapunov solve below is only meaningful for a stable closed loop.\n", | |
"    assert spectral_radius(A_clp) < 1\n", | |
"\n", | |
"    # first, solve for the value function\n", | |
"    # V = dlyap(A+BK, S + K^T R K)\n", | |
"    V = dlyap(A_clp, S + K.T @ R @ K)\n", | |
"\n", | |
"    # form the Q matrix:\n", | |
"    # Q = diag(S, R) + [A^T ; B^T] V [A B]\n", | |
"    Q = scipy.linalg.block_diag(S, R) + np.vstack((A.T, B.T)) @ V @ np.hstack((A, B))\n", | |
"\n", | |
"    # K_next = -Q_{22}^{-1} Q_{12}^T\n", | |
"    # The sym_pos keyword was deprecated and later removed from\n", | |
"    # scipy.linalg.solve; assume_a='pos' is its documented replacement.\n", | |
"    return -scipy.linalg.solve(Q[n:, n:], Q[:n, n:].T, assume_a='pos')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# The cost matrices and the initial zero gain do not depend on the sampled\n", | |
"# system, so they are built once before the search loop.\n", | |
"S = np.eye(n)\n", | |
"R = np.eye(d)\n", | |
"K = np.zeros((d, n))\n", | |
"\n", | |
"for _ in range(1000):\n", | |
"    A, B = random_contractive_system(n, d, rng)\n", | |
"    assert operator_norm(A) < 1\n", | |
"\n", | |
"    K_next = policy_iteration(A, B, K, S, R)\n", | |
"    A_clp_next = A + B @ K_next\n", | |
"    rho_next = spectral_radius(A_clp_next)\n", | |
"    assert rho_next < 1\n", | |
"\n", | |
"    # Stop at the first system whose updated closed loop has operator norm above 1.\n", | |
"    norm_next = operator_norm(A_clp_next)\n", | |
"    if norm_next > 1:\n", | |
"        print(\"counter example found\")\n", | |
"        print(\"spectral_radius\", rho_next, \"operator_norm\", norm_next)\n", | |
"        break" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment