Skip to content

Instantly share code, notes, and snippets.

@aseyboldt
Created August 19, 2022 21:10
Show Gist options
  • Save aseyboldt/d3fcb30178ef94b42d8d0df0e9d5fe3e to your computer and use it in GitHub Desktop.
Save aseyboldt/d3fcb30178ef94b42d8d0df0e9d5fe3e to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "c9a75c9f-f473-4746-b868-098da478e91f",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy import linalg, special, stats\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "072271ca-a114-44b3-9fb8-c9a758618426",
"metadata": {},
"outputs": [],
"source": [
"# Problem size: dimension N of the Gaussian target and number of draws k.\n",
"N = 1000\n",
"k = 50\n",
"\n",
"def make_cov(N, eigdist, diagdist):\n",
"    \"\"\"Draw a random N x N covariance D @ (V @ diag(eigs) @ V.T) @ D.\n",
"\n",
"    eigdist supplies the N eigenvalues of the core matrix, diagdist the N\n",
"    diagonal scales D (both frozen scipy.stats distributions), and V is a\n",
"    random orthogonal matrix from scipy.stats.ortho_group. All draws use\n",
"    NumPy's global random state, so np.random.seed makes this reproducible.\n",
"    \"\"\"\n",
"    eigs = eigdist.rvs(N)\n",
"    diag = diagdist.rvs(N)\n",
"    vecs = stats.ortho_group(N).rvs()\n",
"    # Scale columns/rows by broadcasting instead of multiplying with dense\n",
"    # diagonal matrices: same values, without three extra O(N^3) matmuls.\n",
"    core = (vecs * eigs) @ vecs.T\n",
"    return diag[:, None] * core * diag[None, :]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "09842717-b41d-40ae-90a1-3161be93c677",
"metadata": {},
"outputs": [],
"source": [
"# Seed the global RNG so the random covariance and the draws below are reproducible.\n",
"np.random.seed(43)\n",
"eig_dist = stats.lognorm(s=2, scale=1)  # log-normal eigenvalues of the core matrix\n",
"scale_dist = stats.lognorm(s=1, scale=1)  # log-normal per-coordinate scales\n",
"cov = make_cov(N, eig_dist, scale_dist)\n",
"#sns.clustermap(cov, center=0)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2c4364d3-c82a-44cb-ab21-36ae72f56a31",
"metadata": {},
"outputs": [],
"source": [
"def draw_value_grads(cov, k):\n",
"    \"\"\"Sample k draws from N(0, cov) and the log-density gradient at each.\n",
"\n",
"    For a zero-mean Gaussian the gradient is -cov^{-1} x; it is computed with\n",
"    a positive-definite solve instead of an explicit inverse.\n",
"    Returns (draws, grads), both of shape (k, len(cov)).\n",
"    \"\"\"\n",
"    dim = len(cov)\n",
"    samples = np.random.multivariate_normal(np.zeros(dim), cov, size=k)\n",
"    scores = linalg.solve(cov, samples.T, assume_a=\"pos\").T\n",
"    return samples, -scores"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "41f9aeeb-b148-4ab2-a8d7-b0a6f56baecb",
"metadata": {},
"outputs": [],
"source": [
"# Exact draws from N(0, cov) together with their log-density gradients.\n",
"points, grads = draw_value_grads(cov=cov, k=k)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "af6ccd3c-e9d6-4013-b96f-58c5596eb2f4",
"metadata": {},
"outputs": [],
"source": [
"#plt.plot(np.log(linalg.eigvalsh(np.cov(grads.T))))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "3dba24ef-eb98-4afe-ac2a-043790f4dccc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f7d6069d7c0>]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Log-eigenvalues (ascending) of the true covariance matrix.\n",
"plt.plot(np.log(linalg.eigvalsh(cov)))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "061618c7-9dd5-404c-a8da-b8f419be4a0c",
"metadata": {},
"outputs": [],
"source": [
"# Center draws and gradients per coordinate (each column gets mean zero).\n",
"points = points - points.mean(axis=0)\n",
"grads = grads - grads.mean(axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2417b107-ec82-4149-b758-ab72763e7bac",
"metadata": {},
"outputs": [],
"source": [
"# Orthonormal basis (rows of `subspace`) of the span of all centered draws and gradients.\n",
"span = np.vstack([grads, points]).T\n",
"left_vecs, svdvals, _ = linalg.svd(span, full_matrices=False)\n",
"\n",
"# Centering removes one dimension from each of the two sets, so the last two\n",
"# singular values are (numerically) zero; drop their singular vectors.\n",
"subspace = left_vecs[:, :-2].T"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f73f83aa-1dcc-49bc-9c50-80235a0062fc",
"metadata": {},
"outputs": [],
"source": [
"# Coordinates of the centered draws and gradients in the orthonormal basis `subspace`.\n",
"basis_t = subspace.T\n",
"points_sub = points @ basis_t\n",
"grads_sub = grads @ basis_t"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "5d556afa-3035-41bb-8fbf-ccdfec43c399",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_140260/1413332209.py:1: RuntimeWarning: invalid value encountered in log\n",
" plt.plot(np.log(linalg.eigvalsh(np.cov(points_sub.T))[11:]))\n"
]
},
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f7d605ac130>]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# np.cov of k centered draws has rank k - 1, so the k - 1 smallest of the\n",
"# 2k - 2 eigenvalues are numerically zero (possibly negative); skip all of\n",
"# them so np.log only sees positive values.\n",
"plt.plot(np.log(linalg.eigvalsh(np.cov(points_sub.T))[k - 1:]))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "ceadacb7-6493-41fd-b16a-a04dd8fb3243",
"metadata": {},
"outputs": [],
"source": [
"# Eigendecompositions of the sample covariances of the centered draws (p) and\n",
"# gradients (g) in the projected subspace, via SVD of the data matrices:\n",
"# columns of Up/Ug are eigenvectors, sp/sg are turned into eigenvalues below.\n",
"Up, sp, _ = linalg.svd(points_sub.T, full_matrices=False)\n",
"Ug, sg, _ = linalg.svd(grads_sub.T, full_matrices=False)\n",
"\n",
"# remove one value due to zero eigenvalue from mean subtraction\n",
"Up = Up[:, :-1]\n",
"sp = sp[:-1]\n",
"Ug = Ug[:, :-1]\n",
"sg = sg[:-1]\n",
"\n",
"# Singular values -> variances: (s / sqrt(k))**2 = s**2 / k (divides by k, not k - 1).\n",
"sp /= np.sqrt(k)\n",
"sg /= np.sqrt(k)\n",
"sp = sp ** 2\n",
"sg = sg ** 2\n",
"\n",
"# Ratio sp.min() / sg.min(), computed in log space. Later scaled by epsilon\n",
"# to fill the directions not spanned by Up / Ug.\n",
"others_p = np.exp(np.log(sp.min()) - np.log(sg.min()))\n",
"others_g = 1 / others_p\n",
"\n",
"#others_p = others_g = 1"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "2b94688a-dea1-4279-9928-ede55c2f552c",
"metadata": {},
"outputs": [],
"source": [
"# Regularization scale: epsilon * others_p / others_g is the eigenvalue assigned\n",
"# to directions outside the estimated eigenvectors when building Cp / Cg.\n",
"epsilon = 1e-6"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "d1ca02e2-f19d-45fb-84bb-a81fe116c40a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(98, 49)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Up.shape"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "48bc4a7e-fcbd-4784-be5f-d731d2d93a0f",
"metadata": {},
"outputs": [],
"source": [
"def low_rank_plus_scaled_identity(vecs, vals, fill):\n",
"    \"\"\"Return vecs @ diag(vals) @ vecs.T + fill * (I - vecs @ vecs.T).\n",
"\n",
"    vecs has orthonormal columns (n x r). The result has eigenvalues `vals`\n",
"    along those columns and `fill` on their orthogonal complement. Sizes are\n",
"    taken from `vecs`, so no k-dependent dimensions are hard-coded.\n",
"    \"\"\"\n",
"    n, r = vecs.shape\n",
"    return vecs @ (np.diag(vals) - fill * np.eye(r)) @ vecs.T + fill * np.eye(n)\n",
"\n",
"# Regularized covariance estimates of draws (Cp) and gradients (Cg): sample\n",
"# eigenvalues on the estimated eigenvectors, epsilon * others_* elsewhere.\n",
"Cp = low_rank_plus_scaled_identity(Up, sp, epsilon * others_p)\n",
"Cg = low_rank_plus_scaled_identity(Ug, sg, epsilon * others_g)\n",
"\n",
"# Matrix functions of Cg share its eigenvectors: apply sqrt / 1/sqrt to every eigenvalue.\n",
"Cg_sqrt = low_rank_plus_scaled_identity(Ug, np.sqrt(sg), np.sqrt(epsilon * others_g))\n",
"Cg_invsqrt = low_rank_plus_scaled_identity(Ug, 1 / np.sqrt(sg), 1 / np.sqrt(epsilon * others_g))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "9993a64d-a809-4726-a3cf-3d6d70942805",
"metadata": {},
"outputs": [],
"source": [
"#assert np.allclose(Cg_sqrt, linalg.sqrtm(Cg))\n",
"#assert np.allclose(Cg_invsqrt, linalg.inv(linalg.sqrtm(Cg)))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "c0ed684e-306c-460f-8c32-612b201b2132",
"metadata": {},
"outputs": [],
"source": [
"# Matrix geometric mean of Cg^{-1} and Cp, i.e. the SPD solution M of\n",
"# M @ Cg @ M = Cp:  M = Cg^{-1/2} (Cg^{1/2} Cp Cg^{1/2})^{1/2} Cg^{-1/2}.\n",
"# NOTE(review): linalg.sqrtm may return a complex array if round-off yields\n",
"# slightly negative eigenvalues - confirm `mean` is real before linalg.eigh.\n",
"mean = Cg_invsqrt @ linalg.sqrtm(Cg_sqrt @ Cp @ Cg_sqrt) @ Cg_invsqrt"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "02f90fe0-6428-4cd1-86b1-021b1bfcfb72",
"metadata": {},
"outputs": [],
"source": [
"#linalg.eigvalsh(mean)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "3304b761-1053-422a-b397-da9d1854407d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f7d60524550>]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Compare log-eigenvalues of the estimate with -log of the gradient variances\n",
"# and the (reversed, to ascending order) log of the draw variances.\n",
"plt.plot(np.log(linalg.eigvalsh(mean)))\n",
"plt.plot(-np.log(sg))\n",
"plt.plot(np.log(sp)[::-1])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "9e494c5e-7bcb-4c3c-a689-92a32a85e2fd",
"metadata": {},
"outputs": [],
"source": [
"# Eigendecomposition of the subspace estimate; eigenvectors are lifted back to\n",
"# the full N-dimensional space through the orthonormal basis `subspace`.\n",
"S, U_sub = linalg.eigh(mean)\n",
"U = subspace.T @ U_sub"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "d6b33960-cecb-4ec8-b606-b3131ff5de1b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1000, 98)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"U.shape"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "53caf398-c22f-4528-ba60-9359ca64fe4c",
"metadata": {},
"outputs": [],
"source": [
"# Full N x N estimate: eigenvalues S along the columns of U, 1 on their orthogonal complement.\n",
"full = (U * (S - 1)) @ U.T + np.eye(N)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "59b69485-3438-4c63-86fb-ea6e8ca22f45",
"metadata": {},
"outputs": [],
"source": [
"#linalg.eigvalsh(full)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "0bf4f469-0430-48c9-a0b3-13866a45b1b4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f7d605055b0>]"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(np.log(linalg.eigvalsh(cov, full)))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "4050720e-69df-47fe-ab07-7d1a6046168b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"7319.017302693897"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Distance to the true covariance matrix: sum of squared log generalized\n",
"# eigenvalues (squared affine-invariant distance between SPD matrices)\n",
"(np.log(linalg.eigvalsh(full, cov)) ** 2).sum()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "c02974ec-1f60-4d6c-a062-f53ebb252e78",
"metadata": {},
"outputs": [],
"source": [
"#(np.log(linalg.eigvalsh(full_max, cov)) ** 2).sum()"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "cb5582dd-efb3-471d-9508-4089a8d4461c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"7140.027986303995"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Distance for the diagonal estimate built from draw variances Var[x] only\n",
"(np.log(linalg.eigvalsh(np.diag(points.var(0)), cov)) ** 2).sum()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "ca3c3cbf-e17d-4787-b983-20ce12c234ac",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3951.497131807437"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Distance for the diagonal estimate sqrt(Var[x] / Var[grad]) using both draws and gradients\n",
"(np.log(linalg.eigvalsh(np.diag(np.sqrt(points.var(0) / grads.var(0))), cov)) ** 2).sum()"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "c547d39c-f315-4383-bff2-b9c0f88b7d95",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/adr/mambaforge/envs/pymc-dev/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"2022-08-19 16:08:35.041350: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n",
"2022-08-19 16:08:35.041392: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n",
"/home/adr/mambaforge/envs/pymc-dev/lib/python3.9/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dev is an invalid version and will not be supported in a future release\n",
" warnings.warn(\n"
]
}
],
"source": [
"import covadapt\n",
"import covadapt.spd_manifold"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "c3e44ddf-a48d-4a8d-b747-708008ae865f",
"metadata": {},
"outputs": [],
"source": [
"#solver = covadapt.spd_manifold.ManifoldSolver()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "59a1eb89-7811-4574-9d21-42f974d5620c",
"metadata": {},
"outputs": [],
"source": [
"#estimator = covadapt.spd_manifold.CovarianceEstimator(solver.estimator, compute_full_matrix=True)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "6aa516bb-8299-4382-af6b-638d90ca5a7d",
"metadata": {},
"outputs": [],
"source": [
"# Baseline for comparison: covadapt's manifold-optimization covariance estimator.\n",
"# The commented `self._...` arguments below look copied from a class-based\n",
"# wrapper and are not used in this notebook.\n",
"optimizer = covadapt.spd_manifold.PerturbingConjugateGradient(\n",
" min_step_size=1e-6,\n",
" min_gradient_norm=1e-2,\n",
" perturbance_cooldown=20,\n",
" perturbance=0.01,\n",
" #max_iterations=self._maxiter,\n",
" perturbance_stepsize_threshold=1e-4,\n",
" # beta_type=pymanopt.solvers.conjugate_gradient.BetaTypes.PolakRibiere,\n",
" #verbosity=2,\n",
")\n",
"\n",
"solver = covadapt.spd_manifold.ManifoldSolver(\n",
" #n_eigs=self._n_eigs,\n",
" n_eigs=20,\n",
" stiefel_retraction=True,\n",
" exp_stiefel=True,\n",
" # NOTE(review): min_step_size here (1e-8) differs from the one passed to\n",
" # `optimizer` (1e-6); which takes effect when solver_method is given is not\n",
" # visible here - confirm against covadapt.\n",
" optimizer_args={\n",
" \"verbosity\": 2,\n",
" \"min_step_size\": 1e-8,\n",
" \"min_gradient_norm\": 1e-2,\n",
" #\"max_iterations\": self._maxiter,\n",
" },\n",
" solver_method=optimizer,\n",
" #alpha=self._alpha,\n",
" #delta=self._delta,\n",
" #beta=self._beta,\n",
" #gamma=self._gamma,\n",
" gamma=10,\n",
")\n",
"\n",
"estimator = covadapt.spd_manifold.CovarianceEstimator(solver, compute_full_matrix=True)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "735806fd-b5a1-46f8-aab9-cfb12e190e16",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1min 24s, sys: 1.1 s, total: 1min 25s\n",
"Wall time: 21.8 s\n"
]
},
{
"data": {
"text/html": [
"<style>#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover 
label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>CovarianceEstimator(compute_full_matrix=True,\n",
" estimator=&lt;covadapt.spd_manifold.ManifoldSolver object at 0x7f7cf31be370&gt;)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" ><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">CovarianceEstimator</label><div class=\"sk-toggleable__content\"><pre>CovarianceEstimator(compute_full_matrix=True,\n",
" estimator=&lt;covadapt.spd_manifold.ManifoldSolver object at 0x7f7cf31be370&gt;)</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" ><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">estimator: ManifoldSolver</label><div class=\"sk-toggleable__content\"><pre>&lt;covadapt.spd_manifold.ManifoldSolver object at 0x7f7cf31be370&gt;</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" ><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">ManifoldSolver</label><div class=\"sk-toggleable__content\"><pre>&lt;covadapt.spd_manifold.ManifoldSolver object at 0x7f7cf31be370&gt;</pre></div></div></div></div></div></div></div></div></div></div>"
],
"text/plain": [
"CovarianceEstimator(compute_full_matrix=True,\n",
" estimator=<covadapt.spd_manifold.ManifoldSolver object at 0x7f7cf31be370>)"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"estimator.fit(points, grads)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "94a31b96-d8a7-421c-8b87-0c19d66f530d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3755.4621520250294"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Distance of the covadapt estimate to the true covariance matrix\n",
"(np.log(linalg.eigvalsh(estimator.covariance_, cov)) ** 2).sum()"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "e44c9d2b-7c43-4eec-b76b-6dde47d8fc28",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 4.16194474e+01, -1.01955801e+00, 9.22741697e+00, ...,\n",
" 2.24696907e+00, 2.65851672e+00, -4.15119837e+00],\n",
" [-1.01955801e+00, 9.44641100e+00, 2.93655165e+00, ...,\n",
" 3.68409027e-01, 1.73689467e+00, -2.56409377e-02],\n",
" [ 9.22741697e+00, 2.93655165e+00, 4.98757181e+01, ...,\n",
" 2.49232691e+00, 3.76440338e+00, 6.31676795e-02],\n",
" ...,\n",
" [ 2.24696907e+00, 3.68409027e-01, 2.49232691e+00, ...,\n",
" 4.03760086e+00, 8.59499682e-01, -1.44358900e-01],\n",
" [ 2.65851672e+00, 1.73689467e+00, 3.76440338e+00, ...,\n",
" 8.59499682e-01, 7.07855407e+00, -5.33890328e-01],\n",
" [-4.15119837e+00, -2.56409377e-02, 6.31676795e-02, ...,\n",
" -1.44358900e-01, -5.33890328e-01, 1.67265877e+01]])"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cov"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a5604fe8-0c7e-4a27-b509-de4462328194",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pymc",
"language": "python",
"name": "pymc"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment