Skip to content

Instantly share code, notes, and snippets.

@aseyboldt
Created January 26, 2023 19:15
Show Gist options
  • Save aseyboldt/a81daa87f14b6adf1348307d5e6f314c to your computer and use it in GitHub Desktop.
Save aseyboldt/a81daa87f14b6adf1348307d5e6f314c to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "bc872a0f-c564-4ec4-956d-17f6470ce145",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6ef579b2-463b-435d-9350-5b4cf8850022",
"metadata": {},
"outputs": [],
"source": [
"# Simulated targets: row i of mu / sd parameterizes an independent\n",
"# Gaussian N(mu[i], diag(sd[i]**2)) in n_dim dimensions.\n",
"n_dim = 100\n",
"n_draws = 500\n",
"seed = 42  # fixed seed so Restart & Run All reproduces the same figure\n",
"\n",
"# Scale of the true means. Setting this large (e.g. 100) puts mu very far\n",
"# from the initial points, in which case the identity might be better.\n",
"mu_scale = 1.0\n",
"\n",
"rng = np.random.default_rng(seed)\n",
"mu = rng.standard_normal((n_draws, n_dim)) * mu_scale\n",
"sd = np.exp(rng.standard_normal((n_draws, n_dim)))\n",
"\n",
"# Points at which the gradients are evaluated\n",
"init_points = rng.standard_normal((n_draws, n_dim))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b530a7c-b038-4f91-bcb6-0f8990dcb79d",
"metadata": {},
"outputs": [],
"source": [
"# Gradient of the Gaussian log density log N(x | mu, sd**2) with respect\n",
"# to x, evaluated elementwise at x = init_points: -(x - mu) / sd**2\n",
"grads = -(init_points - mu) / np.square(sd)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "61f104ba-46cb-4428-8c0d-02daab88dabf",
"metadata": {},
"outputs": [],
"source": [
"def dist(estimate, true, axis=-1):\n",
"    \"\"\"Log-eigenvalue (affine-invariant) distance between diagonal SPD matrices.\n",
"\n",
"    For diagonal matrices with positive diagonals ``estimate`` and ``true``,\n",
"    the eigenvalues of ``inv(estimate) @ true`` are ``true / estimate``. This\n",
"    returns the Euclidean norm of their logarithms,\n",
"    ``sqrt(sum(log(true / estimate) ** 2))``. It is zero iff the diagonals\n",
"    agree and is (up to rounding) symmetric in its two arguments.\n",
"\n",
"    Note: this is not the log condition number of ``inv(estimate) @ true``,\n",
"    which would be ``log(max(ratio)) - log(min(ratio))``.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    estimate, true : array_like\n",
"        Strictly positive diagonal entries; broadcast against each other.\n",
"        Non-positive entries yield nan or inf from the logarithm.\n",
"    axis : int, default -1\n",
"        Axis holding the diagonal entries; the norm is taken over it.\n",
"\n",
"    Returns\n",
"    -------\n",
"    numpy.ndarray or float\n",
"        One distance per diagonal, over the remaining axes.\n",
"    \"\"\"\n",
"    eigs = true / estimate\n",
"    return np.linalg.norm(np.log(eigs), axis=axis)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "39242b0f-b90b-4d5c-8c81-66daac89c7c1",
"metadata": {},
"outputs": [],
"source": [
"# Distance of three candidate initial variance estimates from the true\n",
"# variances sd**2 (smaller is better):\n",
"#   identity: all ones\n",
"#   grad:     1 / |grad|\n",
"#   grad_sq:  1 / grad**2\n",
"true_var = sd**2\n",
"dists_identity = dist(estimate=np.ones(n_dim), true=true_var)\n",
"dists_grad = dist(estimate=1 / np.abs(grads), true=true_var)\n",
"dists_grad_sq = dist(estimate=1 / (grads * grads), true=true_var)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "44064aff-4eb8-422a-bd89-c4bbdce5d87d",
"metadata": {},
"outputs": [],
"source": [
"# Distribution over draws of each estimate's distance; mass near 0 is better.\n",
"fig, ax = plt.subplots()\n",
"sns.kdeplot(dists_identity, label=\"identity\", ax=ax)\n",
"sns.kdeplot(dists_grad, label=\"grad\", ax=ax)\n",
"sns.kdeplot(dists_grad_sq, label=\"grad_sq\", ax=ax)\n",
"ax.set(\n",
"    xlabel=\"log-eigenvalue distance to true variances\",\n",
"    ylabel=\"density\",\n",
"    title=f\"Initial variance estimates ({n_draws} draws, {n_dim} dims)\",\n",
")\n",
"ax.legend();"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "34afff0f-f004-4b88-94af-e9cc776082ba",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e5d8570-4368-4ebf-b421-632cb6c048dd",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pymc-dev",
"language": "python",
"name": "pymc-dev"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment