Skip to content

Instantly share code, notes, and snippets.

@mariogeiger
Created January 19, 2023 23:55
Show Gist options
  • Save mariogeiger/8a56b4fff3eb7a2cc483529347f82d5b to your computer and use it in GitHub Desktop.
Save mariogeiger/8a56b4fff3eb7a2cc483529347f82d5b to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def gram_schmidt(A: np.ndarray, *, epsilon=1e-4) -> np.ndarray:\n",
"    \"\"\"Orthonormalize the rows of a matrix using the Gram-Schmidt process.\n",
"\n",
"    Rows are processed in order. A row whose residual norm, after removing its\n",
"    projection onto the rows accepted so far, is not larger than ``epsilon`` is\n",
"    treated as linearly dependent and dropped.\n",
"\n",
"    Args:\n",
"        A (np.ndarray): Matrix of shape (m, n) with dtype float64 or complex128.\n",
"        epsilon (float): Absolute tolerance for rank detection.\n",
"\n",
"    Returns:\n",
"        np.ndarray: Matrix of shape (r, n) with orthonormal rows spanning the row\n",
"        space of A, where r (at most min(m, n)) is the detected rank. The result\n",
"        has the dtype of A, also when r == 0.\n",
"\n",
"    Raises:\n",
"        ValueError: If A is not two-dimensional.\n",
"        TypeError: If A is neither float64 nor complex128.\n",
"    \"\"\"\n",
"    # Explicit checks instead of assert, so validation survives python -O.\n",
"    if A.ndim != 2:\n",
"        raise ValueError(\"Gram-Schmidt process only works for matrices.\")\n",
"    if A.dtype not in (np.float64, np.complex128):\n",
"        raise TypeError(\"Gram-Schmidt process only works for float64 or complex128 matrices.\")\n",
"\n",
"    Q = []\n",
"    # P is the orthogonal projector onto span(Q), accumulated as the sum of q q^H.\n",
"    P = np.zeros((A.shape[1], A.shape[1]), dtype=A.dtype)\n",
"    for v in A:\n",
"        v = v - P @ v\n",
"        norm = np.linalg.norm(v)\n",
"        if norm > epsilon:\n",
"            v = v / norm\n",
"            P += np.outer(v, np.conj(v))\n",
"            Q.append(v)\n",
"\n",
"    # Use A's dtype for the empty result too, so the dtype never depends on the rank.\n",
"    return np.stack(Q) if Q else np.empty((0, A.shape[1]), dtype=A.dtype)\n",
"\n",
"\n",
"def extend_basis(A: np.ndarray, *, epsilon=1e-4, returns=\"Q\") -> np.ndarray:\n",
"    \"\"\"Complete the row space of A to an orthonormal basis of the whole space.\n",
"\n",
"    The rows of A are orthonormalized with ``gram_schmidt``. The standard basis\n",
"    vectors are then orthogonalized against them in order, keeping each one whose\n",
"    residual norm is larger than ``epsilon``.\n",
"\n",
"    Args:\n",
"        A (np.ndarray): Matrix of shape (m, n) with dtype float64 or complex128.\n",
"        epsilon (float): Absolute tolerance for rank detection.\n",
"        returns (str): What to return. Can be \"Q\" or \"E\".\n",
"            \"Q\" returns the complete orthonormal basis.\n",
"            \"E\" returns only the rows that extend the row space of A.\n",
"\n",
"    Returns:\n",
"        np.ndarray: For a small ``epsilon``, a matrix of shape (n, n) (if\n",
"        returns=\"Q\") or (n - r, n) (if returns=\"E\"), where r is the rank of A\n",
"        detected by ``gram_schmidt``. With returns=\"Q\" the first r rows are\n",
"        ``gram_schmidt(A, epsilon=epsilon)``.\n",
"\n",
"    Raises:\n",
"        ValueError: If ``returns`` is neither \"Q\" nor \"E\".\n",
"    \"\"\"\n",
"    # Validate up front instead of silently returning None for unknown values.\n",
"    if returns not in (\"Q\", \"E\"):\n",
"        raise ValueError(f\"returns must be 'Q' or 'E', got {returns!r}\")\n",
"\n",
"    Q = gram_schmidt(A, epsilon=epsilon)\n",
"    n = A.shape[1]\n",
"\n",
"    # Allocate P with A's dtype: an empty Q may be float64 even for complex A,\n",
"    # and the in-place complex update below would then raise a casting error.\n",
"    P = np.zeros((n, n), dtype=A.dtype)\n",
"    P += Q.T @ np.conj(Q)\n",
"    E = []\n",
"    for v in np.eye(n, dtype=A.dtype):\n",
"        v = v - P @ v\n",
"        norm = np.linalg.norm(v)\n",
"        if norm > epsilon:\n",
"            v = v / norm\n",
"            P += np.outer(v, np.conj(v))\n",
"            E.append(v)\n",
"\n",
"    E = np.stack(E) if E else np.empty((0, n), dtype=A.dtype)\n",
"\n",
"    if returns == \"E\":\n",
"        return E\n",
"    return np.concatenate([Q, E])\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"A = np.array([[1, 1, 1j, 0], [1, 0, 0, 1.0]])\n",
"Q = gram_schmidt(A)\n",
"\n",
"# The rows of Q are orthonormal...\n",
"np.testing.assert_allclose(np.conj(Q) @ Q.T, np.eye(2), atol=1e-4)\n",
"# ...and span the row space of A: projecting A onto span(Q) leaves it unchanged.\n",
"np.testing.assert_allclose(A @ np.conj(Q).T @ Q, A, atol=1e-4)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"A = np.array([[1j + 1, 1, 0], [1, 2, 0.0]])\n",
"Q = extend_basis(A, returns=\"Q\")\n",
"\n",
"# Q is a complete orthonormal basis of C^3...\n",
"np.testing.assert_allclose(np.conj(Q) @ Q.T, np.eye(3), atol=1e-4)\n",
"# ...whose first two rows span the row space of A (A has rank 2)...\n",
"np.testing.assert_allclose(A @ np.conj(Q[:2]).T @ Q[:2], A, atol=1e-4)\n",
"# ...so the extension alone consists of a single row.\n",
"assert extend_basis(A, returns=\"E\").shape == (1, 3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "f26faf9d33dc8b83cd077f62f5d9010e5bc51611e479f12b96223e2da63ba699"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment