Created
October 1, 2019 02:07
-
-
Save stephentu/15b7aeb62174905d882cf012661d5e20 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import scipy\n", | |
"import scipy.linalg" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"rng = np.random" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"n, d = 3, 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def spectral_radius(M):\n", | |
" return max(np.abs(np.linalg.eigvals(M)))\n", | |
"\n", | |
"\n", | |
"def operator_norm(M):\n", | |
" return np.linalg.norm(M, ord=2)\n", | |
"\n", | |
"\n", | |
"def random_contractive_system(n, d, rng):\n", | |
" B = rng.normal(size=(n, d))\n", | |
" \n", | |
" A = rng.normal(size=(n, n))\n", | |
" A /= operator_norm(A)\n", | |
" return 0.99 * A, B\n", | |
" \n", | |
"\n", | |
"def dlyap(A, Q):\n", | |
" # newer versions of scipy solve A P A^T - P + Q = 0,\n", | |
" # while older ones solve A^T P A - P + Q = 0. I do not\n", | |
" # remember exactly which version of scipy made the change.\n", | |
"\n", | |
" # I am going to assume you have the newer version installed.\n", | |
" # If the assertion below fails, please add an if statement\n", | |
" # that branches on your version.\n", | |
"\n", | |
" P = scipy.linalg.solve_discrete_lyapunov(A.T, Q)\n", | |
"\n", | |
" assert np.allclose(A.T.dot(P).dot(A) - P, -Q)\n", | |
"\n", | |
" return P\n", | |
" \n", | |
" \n", | |
"def policy_iteration(A, B, K, S, R):\n", | |
" \n", | |
" A_clp = A + B @ K\n", | |
" n, d = B.shape\n", | |
" \n", | |
" assert spectral_radius(A_clp) < 1\n", | |
" \n", | |
" # first, solve for the value function\n", | |
" # V = dlyap(A+BK, S + K^T R K)\n", | |
" V = dlyap(A_clp, S + K.T @ R @ K)\n", | |
" \n", | |
" # form the Q matrix:\n", | |
" # Q = diag(S, R) + [A^T ; B^T] V [A B]\n", | |
" Q = scipy.linalg.block_diag(S, R) + np.vstack((A.T, B.T)) @ V @ np.hstack((A, B))\n", | |
" \n", | |
" # K_next = -Q_{22}^{-1} Q_{12}^T\n", | |
" return -scipy.linalg.solve(Q[n:, n:], Q[:n, n:].T, sym_pos=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"counter example found\n", | |
"spectral_radius 0.2838418915724951 operator_norm 1.765756341273163\n" | |
] | |
} | |
], | |
"source": [ | |
"for _ in range(1000):\n", | |
" A, B = random_contractive_system(n, d, rng)\n", | |
" assert operator_norm(A) < 1\n", | |
" K = np.zeros((d, n))\n", | |
" S = np.eye(n)\n", | |
" R = np.eye(d)\n", | |
" K_next = policy_iteration(A, B, K, S, R)\n", | |
" A_clp_next = A + B @ K_next\n", | |
" assert spectral_radius(A_clp_next) < 1\n", | |
" if operator_norm(A_clp_next) > 1:\n", | |
" print(\"counter example found\")\n", | |
" print(\"spectral_radius\", spectral_radius(A_clp_next), \"operator_norm\", operator_norm(A_clp_next))\n", | |
" break" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment