Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tam17aki/491e676d82b5259bc0fd3bcda57bbf3b to your computer and use it in GitHub Desktop.
Save tam17aki/491e676d82b5259bc0fd3bcda57bbf3b to your computer and use it in GitHub Desktop.
diffusion_model_book_2_2_score_based_model.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/tam17aki/491e676d82b5259bc0fd3bcda57bbf3b/diffusion_model_book_2_2_score_based_model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ybGLrN9MeBRz"
},
"outputs": [],
"source": [
"import torch\n",
"# `tqdm.tqdm_notebook` is deprecated and emits a TqdmDeprecationWarning;\n",
"# the notebook progress bar is imported from its own module instead.\n",
"from tqdm.notebook import tqdm\n",
"\n",
"# Seed the global torch RNG so that Restart & Run All reproduces the same\n",
"# synthetic data, parameter initialisation and noise draws.\n",
"SEED = 0\n",
"torch.manual_seed(SEED)\n",
"\n",
"# Every tensor and module in this notebook is placed on this device.\n",
"device = \"cpu\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zJNm-Rep08EY"
},
"outputs": [],
"source": [
"# Toy data: an equal-weight mixture of two 2-D Gaussians centred at (-2, -2) and (2, 2).\n",
"n_samples = int(1e6)\n",
"# NOTE: `sigma` scales the identity passed as the covariance matrix, so it acts as a variance.\n",
"sigma = 0.1\n",
"\n",
"component_cov = sigma * torch.eye(2, dtype=torch.float).to(device)\n",
"component_shape = torch.Size([n_samples // 2])\n",
"\n",
"dist0 = torch.distributions.MultivariateNormal(\n",
"    torch.tensor([-2, -2], dtype=torch.float, device=device),\n",
"    component_cov,\n",
")\n",
"samples0 = dist0.sample(component_shape)\n",
"\n",
"dist1 = torch.distributions.MultivariateNormal(\n",
"    torch.tensor([2, 2], dtype=torch.float, device=device),\n",
"    component_cov,\n",
")\n",
"samples1 = dist1.sample(component_shape)\n",
"\n",
"# Stack both components along the sample axis.\n",
"samples = torch.cat((samples0, samples1), dim=0)\n",
"\n",
"# Standardise each coordinate to zero mean and unit standard deviation.\n",
"mean = samples.mean(dim=0)\n",
"std = samples.std(dim=0)\n",
"normalized_samples = (samples - mean.unsqueeze(0)) / std.unsqueeze(0)\n",
"\n",
"dataset = torch.utils.data.TensorDataset(normalized_samples)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 281
},
"id": "f8wwV52v4Cqr",
"outputId": "1e39d255-5461-4bbc-d736-753c54e03f19"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQoAAAEICAYAAACnA7rCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAAsTAAALEwEAmpwYAAAd90lEQVR4nO3deZRkZ3nf8e+vqrfpWTVah5GMIBbEAmyzGMEBn5AIbNDBCGzjiNhshqODA46dQ2wTk9jAHzE5ccCx4RiUQFjDcrBxhjAYJHAsiC0igcUiDTIDkZgZZC0zmp6lt6ruJ3+87626XV01t0dVU13d+n3OqXPr3nrvUr0897nv+977KiIwMzuT2nofgJmNPgcKM6vkQGFmlRwozKySA4WZVXKgMLNKDhTWk6S3SPrIeh/H2ZD005LuWu/j2GwcKEaQpGdL+htJM5KOSfo/kn5qvY/rbEi6W9KcpJOSjufv8zpJ5/RvLiK+HBGP7ziO557LfT4SOFCMGEk7gP8F/AmwG9gLvBVYWM/jeph+LiK2A48G3g78DvC+9T0kezgcKEbP4wAi4mMRsRQRcxHxhYj4JoCkfyTpS5KOSnpQ0kcl7SpWzmfQ35L0TUmnJb1P0sWSPpfP7jdJOi+XvVxSSLpe0g8l3Svp3/Q6MEnPyJnBcUnfkPSctXyhiJiJiH3APwdeKemJeXuTkv5Q0g8k3SfpPZK25M+eI+mwpDdKuj8f26tLx3KNpDvzdzpSHHexXn7/YeBHgM9IOiXptyV9VtKvd3yvb0p6yVq+yyNWRPg1Qi9gB3AU+CDwAuC8js9/FHgeMAlcCNwM/FHp87uBW4CLSdnI/cDXgScDU8CXgN/PZS8HAvgYsBV4EvAA8Nz8+VuAj+T3e/NxXUM6wTwvz1/Y43vcXWynY/kPgF/L798J7CNlTtuBzwB/kD97DtAE3gaM5/3OFj8P4F7gp/P784CnlNY73Os4gF8Cvlqa/4n8PSbW+3c/yi9nFCMmIk4Azyb9A/9X4AFJ+yRdnD8/GBE3RsRCRDwAvAP4Jx2b+ZOIuC8ijgBfJv1j/F1EzAOfJgWNsrdGxOmI+Bbw34GXdTm0XwH2R8T+iFiOiBuB20j/wGfjh8BuSQKuB/51RByLiJPAfwCuK5VtAG+LiEZE7AdOAY8vfXalpB0R8VBEfH2N+98HPE7SFXn+5cAnImLxLL/HI4oDxQiKiAMR8aqIuBR4IvAo4I8A8mXEx3O6fQL4CHBBxybuK72f6zK/raP8odL7e/L+Oj0aeGm+7Dgu6TgpoO05qy+XMpNjpGxoGvhaaXt/mZcXjkZEszQ/Wzr2XyAFqXsk/bWkZ65l5zlYfgL4lVyx+jLgw2f5HR5xHChGXER8B/gAKWBAOusG8KSI2EE606vP3VxWev8jpLN+p0PAhyNiV+m1NSLevtad5JabvcBXgAdJQesJpe3tjIjOINZVRNwaEdcCFwF/AXyyV9Euyz4I/DJwNTAbEX+71u/wSOVAMWIk/eNcgXdpnr+MdNa7JRfZTkrBZyTtBX5rALv995KmJT0BeDXpjNvpI8DPSfpZSXVJU7ni8NI1fKcdkl4IfJxU5/GtiFgmXVq9U9JFudxeST+7hu1NSPplSTsjogGcAJZ7FL8PeGx5QQ4My8B/xtnEmjhQjJ6TwFXAVyWdJgWIbwNvzJ+/FXgKMAN8FvjzAezzr4GDwBeBP4yIL3QWiIhDwLXA75IqPA+RgtSZ/oY+I+lkLvtmUn3Kq0uf/07e7y35Muom2nUQVV4O3J3Xex0pQ+jmD4B/ly9vyi06HyJV3m6oDmXrRbnm1x6BJF0O/D9gvKMuYNOT9Arg+oh49nofy0bgjMIecSRNA/8SuGG9j2Wj6DtQSLpM0l/lzi93SPqNLmUk6Y8lHcydW57S737NHo5cB/IAqe7if6zz4WwYfV96SNoD7ImIr0vaDnwNeHFE3Fkqcw3w66TmrKuA/xIRV/W1YzMbmr4zio
i4t+jskjvNHCA1gZVdC3wokluAXTnAmNkGMDbIjeXKsScDX+34aC8rO/Uczsvu7bKN60k99qhTf+o0OwZ5iGZWcpKHHoyIC6vKDSxQSNoG/Bnwm7kb8sMSETeQK5l2aHdcpasHdIRm1umm+NQ9ayk3kFYPSeOkIPHRiOjWrn+Elb3/Ls3LzGwDGESrh0jPGDgQEe/oUWwf8Irc+vEMYCYiVl12mNloGsSlx7NIveS+Jen2vOx3SfcMEBHvAfaTWjwOkm7sefXqzZjZqOo7UETEV6i4KSlSG+zr+92Xma0P98w0s0oOFGZWyYHCzCo5UJhZJQcKM6vkQGFmlRwozKySA4WZVXKgMLNKDhRmVsmBwswqOVCYWSUHCjOr5EBhZpUcKMyskgOFmVVyoDCzSg4UZlZpUE/hfr+k+yV9u8fnz5E0I+n2/Pq9QezXzIZjUON6fAB4F2ko+V6+HBEvHND+zGyIBhIoIuLmPEqY2calNSTYsXzuj2MEDbOO4pmSviHpc5KeMMT9mlmfBjr26Bl8HXh0RJzKI5v/BXBFt4LlsUenmB7S4ZnZmQwlo4iIExFxKr/fD4xLuqBH2Rsi4mkR8bRxJodxePZIpdqKl2qqftXrqF5ftW7X1yYylG8j6ZI89CCSnp73e3QY+zaz/g3k0kPSx4DnABdIOgz8PjAOrSEFfxH4NUlNYA64Lo8eZjYcpTO8alq17EzLgXYlZv6sGBovlpZWfr5ilVrngrM65FEyqFaPl1V8/i5S86mZbUDDqsw0Wx9FBlDT6mX1epqvF2VyBlAsV5chdVsZRKxYl6V2tlBkGaqlZbG88ZPnzVXjYmbnhDMK21w6Moh21lBvFyneT4yn+bGxVWUAGC/9exRZQZFRNJtp2miunAcospeiLGnaqrPYgHUVzijMrJIzCtuUWllDkWGMtf/UNZX75xTLtkylac4wol60fnSro0iZhRYW07TIGubm22VydhHzC6lMe+W0vLM1JC2s+EbryxmFmVVyRmGbw6qWjDTVxESejrfL5gwitqZbBJa3pwxjeTL9OyxNpXWjS0JRn09ZQX0hbaM2mzOLcsZy6nSa5mOKhZRZtDKMojWkVYcx+pxRmFklBwozq+RLD9vYOi45WtN8KaDJdOnBtq2tVZa3p0uO5q50+dDclsrO70rrLuerlKXx9rWHcl3j2EKqzBw/nRaMn5xYMQWoF5Wgs3NpupgvT3JzaxRNqip1xHJlppltdM4obENb0TUb2t2xJ1ZmErF1qlWkeV56P39+KjN/XlpnLj/4oMgolttJQquNsz6X3kweT9nHlqNa8TmAmiljqeVu3io6Y+Um1VbnrVIWMeqdsZxRmFklZxS28XS5DbyzjqLVeWpLSguaO9sZxeLO9NnsRWk7s5ek5Y2dudlydzrzj0+1u2U3FtK/yuJsmjanc1PqZG4CVftfqdbI9RaNfHNYkUkUmUXX29lHu6nUGYWZVXJGYRtbcVYubvcubuSazN20c4bR3NrucDV7YVo2d1GaX7gknel3XHISgD07TgDwmG3th7DNLaUs4Xsnzgfg8GSaLo+n7dYW2+fcsfm8bCFlCWOnc6ev3PGqdfNZqcNVFG81mnUVzijMrJIzCtvQOls9Oh9Gs7QtZRZLk+1yzVxd0diaWh+2XDALwDP23APAs3Z+F4ALx0601jm+lFpPbq4/DoBa7gNxz2JKSxZn2v9KjeNp35OT+ViKLKdoicldubs9nm9Uu3U7ozCzSsMae1SS/ljSQUnflPSUQezXzIZjWGOPvoA04M8VwFXAn+ap2bkxtvJpVc3p9jmxsSNNl7anCsPHnp8qLR8z/UCaH78fgF319jMm9tZn8rppuzONLQAc2rI7bWuydOkxnS4jlsdz0+l4brotCrSepNVYfdybuTIzIm4Gjp2hyLXAhyK5Bdglac8g9m1m596wKjP3AodK84fzsns7C3pIQetLcUNW7ia9PNbloRJFD+qxdNaeWUy1m7vr6TkSRS
ZxSb1909b5tfS3eHczVXxuH0sVkhNbUlawXG936CpSh1XPszjTUDatik1XZq6JhxQ0Gz3DyiiOAJeV5i/Ny8zOWrlJtBgzo1brOOfls3etmW/MKp+o84lduZNUcynVGXx3PvXl/smpHwBwKNor3ZmrDO5eTHeO3Te/HYDF2dS5aqL0EO6xuY7ModuzN9PBd38/goaVUewDXpFbP54BzETEqssOMxtNwxp7dD9wDXAQmAVePYj92iNTeeQt1Ytl6YysYsSuZr4hazF3o55rn7EnTqbz48Jcmh47meofDmy5GIB9ejIA0/XFVfv+9slHAXDPzHlpwUPFg2vaZeqL6fjqC3mfjY7RxcpjgKz+cr0/W0fDGns0gNcPYl9mNnzuwm0bW3EGXlr5ZOviYTG1+XzL+Kn2TWETJ1IaMn0k1R3MRuqefaCRlt97KnW0OG/LbGud4/Op38TcQr5F/b5tAEw+VMvbbB/SxMl0LPXTqUWkdZt50T27S53FqI9POnKtHmY2epxR2Maz4jo+j8HRkUm0Hmi7kObrs+1ekFPH85+9iofppjP8fKS+EMdmUr3DMe1q70bFCGHp3Dp1NE23pM6cTN/fbiGZeCjtu34q9+zMGUUUI4gtNvLXGO0soswZhZlVcqAws0q+9LANrUjfezaTnk4VkrWJ9k1ikw/mJ3XnZ1tqKf0bjJ9OlyCNrSvH9wCoFY+7zPWSReXl1EPpkmPqaLsptTaX388X0/z8icWOm8BGtCm0G2cUZlbJGYVtDp3NpPn5lEVDpGZOt4qOLRXdunMzZs4spo6l82ZjW84ouvx3jM3nzlS5A9fE8VxxOTPXKqO5nEHMpWWRK1aLCtfWdNkjhZnZJuKMwja2fCZujbRVdGrqPEOX5pW7UtfzGKC12ZRRxGSqlBg/WTSflm4+y29rizkLmc3Nr7M5e2iW7jo7eSqts9jRLJr3N6rPxTwTZxRmVskZhW0qrbN1kQIsFuN9ljKMoqWkaCHJHaKKEdBrHY/RS9vL9QlLHZlKMVZH6UavzkxiLVnOqHNGYWaVnFHY5tBxdm7VWahjRHEgOm/3LsYtzRkFnSOkQ7u+Iq9b1De050t9JDrKRK+MYgNxRmFmlZxR2Ka2qs4CWnUU7du+82P16yt7Tqr0eL2ix+eqB+QW2yplC6v6SWzgTKLgjMLMKjlQmFmlQQ0p+HxJd+UhA9/U5fNXSXpA0u359dpB7Nesp1he8YrlaL+Wlla+mo30Wkwvmk1oNlmeX2i9imXRyK/FxfTK21huNFuvYj+bSd91FJLqwLuB55EG9rlV0r6IuLOj6Cci4g397s/Mhm8QlZlPBw5GxPcBJH2cNIRgZ6AwWz9dKhRbTaiZaqnMciM3qZbHD2m1cHZmCl2aPkd0/NB+DOLSo9dwgZ1+IY9k/ilJl3X5HEhDCkq6TdJtDRYGcHhm1q9hVWZ+Brg8In4cuBH4YK+CHlLQhqZHPUZrvrMuY2lp1TqtV7ftbiKDCBSVwwVGxNGIKNKD/wY8dQD7NbMhGUSguBW4QtJjJE0A15GGEGyRtKc0+yLgwAD2azZYmzATGJS+KzMjoinpDcDnSc9Of39E3CHpbcBtEbEP+FeSXgQ0gWPAq/rdr5kNj6KzS+oI2aHdcZWuXu/DMNu0bopPfS0inlZVzj0zzaySA4WZVXKgMLNKDhRmVsmBwswqOVCYWSUHCjOr5EBhZpUcKMyskgOFmVVyoDCzSg4UZlbJgcLMKjlQmFklBwozq+RAYWaVHCjMrJIDhZlVGtaQgpOSPpE//6qkywexXzMbjr4DRWlIwRcAVwIvk3RlR7HXAA9FxI8C7wT+Y7/7NbPhGURG0RpSMCIWgWJIwbJraQ/68yngaknCzDaEYQ0p2CoTEU1gBji/28Y8pKDZ6Bm5ykwPKWg2eoYypGC5jKQxYCdwdAD7NrMhGMqQgnn+lfn9LwJfilEeecjMVhjWkILvAz4s6SBpSMHr+t2vmQ1P34ECICL2A/s7lv
1e6f088NJB7MvMhm/kKjPNbPQ4UJhZJQcKM6vkQGFmlRwozKySA4WZVXKgMLNKDhRmVsmBwswqOVCYWSUHCjOr5EBhZpUcKMyskgOFmVVyoDCzSg4UZlbJgcLMKjlQmFmlvgKFpN2SbpT03Tw9r0e5JUm351fng3fNbMT1m1G8CfhiRFwBfDHPdzMXET+ZXy/qc59mNmT9BoryUIEfBF7c5/Y2D9V6vlSvr3h1LWc2Qvr9i7w4Iu7N7/8BuLhHuak8TOAtkl58pg16SEGz0VP5uH5JNwGXdPnozeWZiAhJvQb1eXREHJH0WOBLkr4VEd/rVjAibgBuANih3aM7SFCPs75qXcZezmWLz2I5VsyrzorlaUarl6UFD/eIzR62ykAREc/t9Zmk+yTtiYh7Je0B7u+xjSN5+n1J/xt4MtA1UJjZ6On30qM8VOArgf/ZWUDSeZIm8/sLgGcBd/a5XzMbon5HCns78ElJrwHuAX4JQNLTgNdFxGuBHwPeK2mZFJjeHhGbJlCoXl85X1x6lC9NWsvypUarbCoTy+lyQkuly4oo1smXHvmSI5ZrK+bNhqGvQBERR4Gruyy/DXhtfv83wJP62Y+Zra+BjD26aXWpsOysrGxlFPXainlNTXbZXl63WKfIJJpLAMTSUrtss7miTGueXHa5dGzOLuwcc4O9mVVyRrEGK7IIdWQO4/lHODGe5qem0vx46UdbZBC5TqLIPiiaPhspW9ByKaNYWEzT+YWV6y7m5aWW6ChWc2Zh54gzCjOr5Iyim866idJ8Zx2EJifSB1umVkyXt4y31omJ9GNeHs/bKVo/mikDqC3kjKLRzig0t5h3nddp1VFkRWYBpc5ZbhGxc8MZhZlVckZxBu0u1u2+EirqIvKUbVsBiC0ps1jaOQ1AY1v7R9vYnt5HUTUxnrZba8SK6cRMO2sYO522X8t1HTo9mz4oWkaW21lDuwalo0XEmYUNiDMKM6vkjKKs4+atVS0bgCZyncT2bQDEti0ANHemuon589PnC7vaMXhhZ9pec0uxkTxZSm8mTqb55nR7nclj6f1EbiEZ69Z7M4u53AKSs43WzWdLq4qaPSzOKMyskgOFmVXypUc3RXPoWP7xlG/8KppDc2Vmcckxd2FafmpvKtvY1l5l/qJ0abC8JVc2Fq2k+dJj4URasPhgu1qyObnyV6PcOau+VHT7LjWXLjby8aZjimZj5fdwpab1yRmFmVVyRtHFqhu/Jtqdp8iVmctbU4erxo702ek9KZOYuygVW7ywfcbffskpAC7anmotJ+opszg6m5pWHzyW0o/5+lR5rwDUG+lXNDaX1mk1l46VfnXF+yXXXtq54YzCzCo5oyhZlUmM50yidPaOLSmTaG5Nny1uy/UL29Pni+ens/qOPSdb6zz1kkMAPG7rfQBM11L360ML5wPwd5N7Abi7fn5rnYVG6rg1cTId0+LOdAy1xdTGOlbUSwDMzeUD7nxojjMMGwxnFGZWyRkF9B5Ho+gmXSt9PpbeL0/k1o3pdBZfLp5TM5nWefSuh1qr/NSOuwF40mTKLLbmjOKByVQ3saOeMoLTi+2H3dy3NdVXFJ2wliZX3lDWmtLOfCLfkt7KjMIdr2ww+h1S8KWS7pC0nJ+T2avc8yXdJemgpF6jiZnZiOo3o/g28PPAe3sVkFQH3g08DzgM3Cpp34Z4wG7Hg3OB1pk86sWZPU2KvhH16VR3cMHUqdYqW2vpTH/FeLqxa1rpx767lrKO742n+oyxevvUvzyd3i9N5Qwm/6aWx7qMGxLFOCH5Yb3F4s4xQcwepn4frnsAQOryx9v2dOBgRHw/l/04aSjC0Q8UZgYMp45iL3CoNH8YuKpXYUnXA9cDTDF9bo+s50EUPRq7nJG7LQOiI/k40djSer+U0427cnfNx+YM4r6lVA8xs5RvLFsq3c6+mG9Qy0lGLXfLaD0Br166auxyo9gK7qFpfeprSMGIWDXgT782zJCCZo8gfQ0puEZHgM
tK85fmZWa2QQzj0uNW4ApJjyEFiOuAfzGE/a5dKyWvr5xf7piW3teK510W92OdTtOF/GSqH57a0Vrl77etTMgW83XKnQupo9Wh+d0AHD2xtVWmfjo/hyL3pao189Ow8vM1aZbaPFuDHxejiXlgYxusfptHXyLpMPBM4LOSPp+XP0rSfoCIaAJvAD4PHAA+GRF39HfYZjZM/bZ6fBr4dJflPwSuKc3vB/b3s69hap2RiydGlZ9Pmc/o9dNpOjafMojx0+msPv5gmv+HiZ2tdb5SeywA92xLmcMFEyn9+MHseQD88GTKPhpH2zeFTRbbO5lvL5/Lt5cXT+putLtwt0YY61HRatYvd+E2s0ruwl1SZBIqWimLZsdyfUAeT6N+OnWimjqWfoRLxe3nuUPU/FI7OzjcTDd7HRlPGcTYRNpeYz7/+E+n6cSD7ebR6XT/GFPH0zGMn0oZhOZzHUUxghi0RhprjTzmOgkbMGcUZlbJGUU3xRm5uKmqNCqXilHL863nYzOpTmK6dYt6Ho+j0e6t2jiZso1m8Xi8fOLP948xfiLPz7QPYTJnEpPHGnk/82m/J1LX8Cg/Ci/XTRR1FeHMwgbMGYWZVXJGURZFP4Tilu58Zu7yIFvlOoL6TB77o7gxK/d3GD/d/tEubk/ba07lLKN4du9sXief+Cdn2nUhEzMdmcTJ3KGiaO1otI8pGh3jkpoNmDMKM6vkjOIMotvDaos7ZYvRxPJ8Lfe1mMj9HMZm2w/kncj1GMWYo60bu/K0yEbGj7dbMjSfRzM/lTOJuZRZxEJaHovd+lEsr5yaDYgzCjOr5EBhZpV86bEW5VQ+V2zGfJ7vGLmrmNbnJlqr1IqKzeLZmx3P4tRiroycm2+t0+rklStNiybaVhNo+bIoetwMZjYgzijMrJIzim6KDCI/GWrlmTqfyVs3jnXcip6bL1V63qYm2tkF0H46VZE1FNlB6eazorIyiubQvJ9VFZerjs9s8JxRmFklZxRn0qWZseiMpVpxhs8fFL28i2bL8dKPtqh7KLKMoom1uC18uSNbKH3W6kx1pnoIN4faOeaMwswqOaM4W53dvOmoM+j2xOtiWa+u1mfIFlZ1+nL2YOvAGYWZVeoro5D0UuAtwI8BT4+I23qUuxs4STr9NiOi5/CDG0bHmb0zwyiP99kaC7Qz2+iYP2PrhTMJW0fnfEjBkn8aEQ/2uT8zWwfDGFLQzDa4YVVmBvAFSQG8N48Gtrl0XhqoVvqoeOJUY2WZoiLUQ/7ZiBvWkILPjogjki4CbpT0nYi4ucf+1n/sUTNbYRhDChIRR/L0fkmfJo1w3jVQbJqxR88mO3AmYSPunDePStoqaXvxHvgZUiWomW0Q53xIQeBi4CuSvgH8X+CzEfGX/ezXzIbrnA8pGBHfB36in/2Y2fpyz0wzq+RAYWaVHCjMrJIDhZlVcqAws0oOFGZWyYHCzCo5UJhZJQcKM6vkQGFmlRwozKySA4WZVXKgMLNKDhRmVsmBwswqOVCYWSUHCjOr5EBhZpUcKMysUr8P1/1Pkr4j6ZuSPi1pV49yz5d0l6SDkt7Uzz7NbPj6zShuBJ4YET8O/D3wbzsLSKoD7wZeAFwJvEzSlX3u18yGqK9AERFfiIhmnr0FuLRLsacDByPi+xGxCHwcuLaf/ZrZcA1y7NFfBT7RZfle4FBp/jBwVa+NlIcUBBZuik9txsGCLgA248jum/V7web9bo9fS6GBjD0q6c1AE/jo2RxhN+UhBSXdFhFP63ebo8bfa+PZrN9N0m1rKdf32KOSXgW8ELg6IrqNFXoEuKw0f2leZmYbRL+tHs8Hfht4UUTM9ih2K3CFpMdImgCuA/b1s18zG65+Wz3eBWwHbpR0u6T3wMqxR3Nl5xuAzwMHgE9GxB1r3P4NfR7fqPL32ng263db0/dS96sFM7M298w0s0oOFGZWaaQDxVq7iG9Ekl4q6Q5Jy5I2fLPbZu2mL+n9ku6XtKn680i6TNJfSb
oz/x3+xpnKj3SgYA1dxDewbwM/D9y83gfSr03eTf8DwPPX+yDOgSbwxoi4EngG8Poz/c5GOlCssYv4hhQRByLirvU+jgHZtN30I+Jm4Nh6H8egRcS9EfH1/P4kqUVyb6/yIx0oOvwq8Ln1Pgjrqls3/Z5/dDZaJF0OPBn4aq8yg7zX42EZdhfxYVrLdzNbT5K2AX8G/GZEnOhVbt0DxQC6iI+squ+2ibib/gYkaZwUJD4aEX9+prIjfemxxi7itv7cTX+DkSTgfcCBiHhHVfmRDhT06CK+GUh6iaTDwDOBz0r6/Hof08PVZzf9kSbpY8DfAo+XdFjSa9b7mAbkWcDLgX+W/7dul3RNr8Luwm1mlUY9ozCzEeBAYWaVHCjMrJIDhZlVcqAws0oOFGZWyYHCzCr9f080HOw3wFo2AAAAAElFTkSuQmCC\n"
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"%matplotlib inline\n",
"import numpy as np\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plot_samples = normalized_samples.cpu().numpy()\n",
"\n",
"# Density-normalised 2-D histogram of the standardised samples on [-2, 2] x [-2, 2].\n",
"fig, ax = plt.subplots()\n",
"ax.hist2d(\n",
"    plot_samples[:, 0],\n",
"    plot_samples[:, 1],\n",
"    bins=100,\n",
"    range=((-2, 2), (-2, 2)),\n",
"    density=True,\n",
"    cmap='viridis',\n",
"    rasterized=False,\n",
")\n",
"ax.set_aspect('equal', adjustable='box')\n",
"ax.set_xlim(-2, 2)\n",
"ax.set_ylim(-2, 2)\n",
"ax.set_title('Sample Density')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x0gDBxSZdq8c"
},
"outputs": [],
"source": [
"# Noise schedule: T levels spaced evenly in log-space from sigma_begin up to sigma_end.\n",
"sigma_begin = 0.001\n",
"sigma_end = 1.0\n",
"T = 200\n",
"\n",
"log_sigmas = np.linspace(np.log(sigma_begin), np.log(sigma_end), T)\n",
"sigmas = torch.from_numpy(np.exp(log_sigmas)).to(dtype=torch.float32, device=device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "maXGG_Mdefn_"
},
"outputs": [],
"source": [
"def dsm_loss(score_model, samples, sigmas):\n",
"    \"\"\"Denoising score matching loss with one random noise level per sample.\n",
"\n",
"    Each sample is perturbed with Gaussian noise whose scale is picked uniformly\n",
"    at random from `sigmas`. The model output is regressed onto\n",
"    `-noise / sigma**2` and the squared error of each sample is weighted by\n",
"    `sigma**2`.\n",
"\n",
"    Args:\n",
"        score_model: Callable invoked as `score_model(perturbed_samples, used_sigmas)`,\n",
"            where `used_sigmas` has shape `(batch, 1, ..., 1)` with one singleton\n",
"            axis per non-batch axis of `samples`.\n",
"        samples: Tensor whose first axis is the batch axis.\n",
"        sigmas: 1-D tensor of noise levels; indices are drawn on `sigmas.device`.\n",
"\n",
"    Returns:\n",
"        Scalar tensor: the weighted loss averaged over the batch.\n",
"    \"\"\"\n",
"    batch = samples.shape[0]\n",
"    t = torch.randint(0, len(sigmas), (batch,), device=sigmas.device)\n",
"    # Shape (batch, 1, ..., 1) so the noise level broadcasts over every non-batch axis.\n",
"    used_sigmas = sigmas[t].view(batch, *([1] * len(samples.shape[1:])))\n",
"    noise = torch.randn_like(samples) * used_sigmas\n",
"    perturbed_samples = samples + noise\n",
"    target = - 1 / (used_sigmas ** 2) * noise\n",
"    scores = score_model(perturbed_samples, used_sigmas)\n",
"    # `reshape` also accepts a non-contiguous model output, which `view` would reject.\n",
"    target = target.reshape(batch, -1)\n",
"    scores = scores.reshape(scores.shape[0], -1)\n",
"    # The weight must be exactly (batch,). Squeezing only the last axis left extra\n",
"    # singleton axes for samples with more than one non-batch axis, which broadcast\n",
"    # the per-sample loss into a batch-by-batch cross product.\n",
"    w = used_sigmas.reshape(batch) ** 2\n",
"    loss = ((scores - target) ** 2).sum(dim=-1) * w\n",
"    return loss.mean()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-Lwg4kMVM03Z"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class ScoreModel(nn.Module):\n",
"    \"\"\"Fully connected network that maps a point and its noise level to a score estimate.\n",
"\n",
"    The noise level is appended to the point as one extra input feature, so the\n",
"    first layer takes `n_channels + 1` inputs and the last layer returns\n",
"    `n_channels` outputs.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, n_channels=2):\n",
"        super().__init__()\n",
"\n",
"        # Layer widths from input (point + noise level) to output (score).\n",
"        widths = [n_channels + 1, 2 * n_channels, 16 * n_channels, 2 * n_channels, n_channels]\n",
"        layers = []\n",
"        for fan_in, fan_out in zip(widths[:-1], widths[1:]):\n",
"            layers.append(nn.Linear(fan_in, fan_out))\n",
"            layers.append(nn.ELU())\n",
"        layers.pop()  # the output layer has no activation\n",
"        self.model = nn.Sequential(*layers)\n",
"\n",
"    def forward(self, x, sigma):\n",
"        \"\"\"Return the network output for points `x` and per-sample noise levels `sigma`.\n",
"\n",
"        `sigma` is concatenated to `x` along dim 1, so both need the same\n",
"        leading (batch) size.\n",
"        \"\"\"\n",
"        return self.model(torch.cat((x, sigma), dim=1))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "P9QWytgpa0pn",
"outputId": "d6a8f58d-c1c5-41f4-b5e4-173175143409"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0 steps loss:2.10343599319458\n",
"1000 steps loss:2.138179063796997\n",
"2000 steps loss:1.8900189399719238\n",
"3000 steps loss:1.9138953685760498\n",
"4000 steps loss:1.9245710372924805\n",
"5000 steps loss:1.914085865020752\n",
"6000 steps loss:1.8996541500091553\n",
"7000 steps loss:1.831539273262024\n",
"8000 steps loss:1.8835861682891846\n",
"9000 steps loss:1.809353232383728\n",
"10000 steps loss:1.6475526094436646\n",
"11000 steps loss:1.6848489046096802\n",
"12000 steps loss:1.5485844612121582\n",
"13000 steps loss:1.7400761842727661\n",
"14000 steps loss:1.601453423500061\n",
"15000 steps loss:1.6501575708389282\n",
"16000 steps loss:1.5780878067016602\n",
"17000 steps loss:1.6547218561172485\n",
"18000 steps loss:1.5332990884780884\n",
"19000 steps loss:1.5731066465377808\n",
"20000 steps loss:1.6006646156311035\n",
"21000 steps loss:1.5818557739257812\n",
"22000 steps loss:1.4299019575119019\n",
"23000 steps loss:1.6383044719696045\n",
"24000 steps loss:1.6676623821258545\n",
"25000 steps loss:1.5119620561599731\n",
"26000 steps loss:1.522580623626709\n",
"27000 steps loss:1.462673306465149\n",
"28000 steps loss:1.533888578414917\n",
"29000 steps loss:1.4086380004882812\n",
"30000 steps loss:1.5557557344436646\n",
"31000 steps loss:1.4251868724822998\n",
"32000 steps loss:1.4696316719055176\n",
"33000 steps loss:1.461082100868225\n",
"34000 steps loss:1.487654209136963\n",
"35000 steps loss:1.4957597255706787\n",
"36000 steps loss:1.5456204414367676\n",
"37000 steps loss:1.7313734292984009\n",
"38000 steps loss:1.5402085781097412\n",
"39000 steps loss:1.450195550918579\n",
"40000 steps loss:1.4396100044250488\n",
"41000 steps loss:1.516808271408081\n",
"42000 steps loss:1.5489423274993896\n",
"43000 steps loss:1.4515758752822876\n",
"44000 steps loss:1.541172981262207\n",
"45000 steps loss:1.467519998550415\n",
"46000 steps loss:1.4105401039123535\n",
"47000 steps loss:1.4251112937927246\n",
"48000 steps loss:1.3971335887908936\n",
"49000 steps loss:1.5265257358551025\n",
"50000 steps loss:1.4675631523132324\n",
"51000 steps loss:1.520257830619812\n",
"52000 steps loss:1.3204165697097778\n",
"53000 steps loss:1.5094389915466309\n",
"54000 steps loss:1.4009146690368652\n",
"55000 steps loss:1.502292275428772\n",
"56000 steps loss:1.6551986932754517\n",
"57000 steps loss:1.566804051399231\n",
"58000 steps loss:1.7013672590255737\n",
"59000 steps loss:1.5082180500030518\n",
"60000 steps loss:1.4296739101409912\n",
"61000 steps loss:1.5015251636505127\n",
"62000 steps loss:1.5544426441192627\n",
"63000 steps loss:1.4355849027633667\n",
"64000 steps loss:1.4423151016235352\n",
"65000 steps loss:1.5445194244384766\n",
"66000 steps loss:1.4904658794403076\n",
"67000 steps loss:1.4084433317184448\n",
"68000 steps loss:1.502392053604126\n",
"69000 steps loss:1.458216905593872\n",
"70000 steps loss:1.5076521635055542\n",
"71000 steps loss:1.5410008430480957\n",
"72000 steps loss:1.5718374252319336\n",
"73000 steps loss:1.526256799697876\n",
"74000 steps loss:1.5020813941955566\n",
"75000 steps loss:1.5269545316696167\n",
"76000 steps loss:1.387231707572937\n",
"77000 steps loss:1.4536584615707397\n",
"78000 steps loss:1.4981789588928223\n",
"79000 steps loss:1.421187162399292\n",
"80000 steps loss:1.4645973443984985\n",
"81000 steps loss:1.4400806427001953\n",
"82000 steps loss:1.4696435928344727\n",
"83000 steps loss:1.6755788326263428\n",
"84000 steps loss:1.4436883926391602\n",
"85000 steps loss:1.4473791122436523\n",
"86000 steps loss:1.4851280450820923\n",
"87000 steps loss:1.5375834703445435\n",
"88000 steps loss:1.4770243167877197\n",
"89000 steps loss:1.5834132432937622\n",
"90000 steps loss:1.5765262842178345\n",
"91000 steps loss:1.5022810697555542\n",
"92000 steps loss:1.548335075378418\n",
"93000 steps loss:1.641385555267334\n",
"94000 steps loss:1.4702154397964478\n",
"95000 steps loss:1.5113106966018677\n",
"96000 steps loss:1.4681692123413086\n",
"97000 steps loss:1.568758487701416\n",
"98000 steps loss:1.4290732145309448\n",
"99000 steps loss:1.5507277250289917\n"
]
}
],
"source": [
"import torch\n",
"\n",
"batch_size = 512\n",
"n_steps = 100000\n",
"\n",
"# The loader now takes its size from `batch_size`; it used to repeat the literal 512,\n",
"# so editing the variable above had no effect.\n",
"dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n",
"dataloader_iter = iter(dataloader)\n",
"\n",
"score_model = ScoreModel().to(device)\n",
"\n",
"optimizer = torch.optim.Adam(score_model.parameters())\n",
"# One learning-rate cycle spread over exactly `n_steps` optimizer steps.\n",
"lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, total_steps=n_steps)\n",
"\n",
"for i in range(n_steps):\n",
"    # Training is counted in steps, not epochs: restart the loader whenever it runs out.\n",
"    try:\n",
"        x = next(dataloader_iter)[0]\n",
"    except StopIteration:\n",
"        dataloader_iter = iter(dataloader)\n",
"        x = next(dataloader_iter)[0]\n",
"    x = x.to(device)\n",
"\n",
"    optimizer.zero_grad()\n",
"    loss = dsm_loss(score_model, x, sigmas)\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    lr_scheduler.step()\n",
"    # Report the loss of the current mini-batch every 1000 steps.\n",
"    if (i % 1000) == 0:\n",
"        print(f\"{i} steps loss:{loss}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9P_sAS2Fpa-o"
},
"outputs": [],
"source": [
"def sbm_sample(n_samples, score_model, sigmas, alpha=0.1, n_inner_steps=200, dim=2):\n",
"    \"\"\"Draw samples with annealed Langevin dynamics driven by `score_model`.\n",
"\n",
"    Starting from Gaussian noise scaled by the last entry of `sigmas`, the noise\n",
"    levels are visited from the last to the first. At each level the update\n",
"    `x <- x + alpha_t * score + sqrt(2 * alpha_t) * u` is applied\n",
"    `n_inner_steps + 1` times, with step size\n",
"    `alpha_t = alpha * sigma_t**2 / sigmas[-1]**2`. No noise is added on the\n",
"    very last update.\n",
"\n",
"    Args:\n",
"        n_samples: Number of samples to generate.\n",
"        score_model: Callable invoked as `score_model(x, sigma)` with `x` of shape\n",
"            `(n_samples, dim)` and `sigma` of shape `(n_samples, 1)`.\n",
"        sigmas: Sequence of noise levels; sampling runs from `sigmas[-1]` down to\n",
"            `sigmas[0]`. All tensors are created on the device of `sigmas`.\n",
"        alpha: Step size used at the noise level `sigmas[-1]`.\n",
"        n_inner_steps: Value of K; every noise level runs K + 1 updates.\n",
"        dim: Dimensionality of each sample.\n",
"\n",
"    Returns:\n",
"        Tensor of shape `(n_samples, dim)` with the final samples.\n",
"    \"\"\"\n",
"    sigmas = torch.as_tensor(sigmas)\n",
"    # Create every tensor where the noise schedule lives, so sampling also works\n",
"    # when the model and `sigmas` are not on the CPU.\n",
"    device = sigmas.device\n",
"    sigma_T = sigmas[-1]\n",
"    x_tk = torch.randn(n_samples, dim, device=device) * sigma_T\n",
"    K = n_inner_steps\n",
"    with torch.no_grad():\n",
"        for t in range(len(sigmas) - 1, -1, -1):\n",
"            sigma_t = sigmas[t]\n",
"            alpha_t = alpha * (sigma_t ** 2) / (sigma_T ** 2)\n",
"            print(f\"t:{t}, sigma_t:{sigma_t}, alpha_t:{alpha_t}\")\n",
"            # Both values are constant within one noise level, so build them once.\n",
"            sigma_t_dup = torch.ones((n_samples, 1), device=device) * sigma_t\n",
"            # torch.sqrt keeps the computation on `device`; numpy cannot read non-CPU tensors.\n",
"            noise_scale = torch.sqrt(2 * alpha_t)\n",
"            for k in range(K + 1):\n",
"                if (k == K) and t == 0:\n",
"                    # Final update of the whole run is noise-free.\n",
"                    u_k = torch.zeros(n_samples, dim, device=device)\n",
"                else:\n",
"                    u_k = torch.randn(n_samples, dim, device=device)\n",
"                score = score_model(x_tk, sigma_t_dup)\n",
"                x_tk = x_tk + alpha_t * score + noise_scale * u_k\n",
"    return x_tk"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ixng3Ojiqp_J",
"outputId": "bfc4347c-b13f-4194-e205-1b598924362b"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"t:199, sigma_t:1.0, alpha_t:0.10000000149011612\n",
"t:198, sigma_t:0.965883195400238, alpha_t:0.09329303354024887\n",
"t:197, sigma_t:0.9329304099082947, alpha_t:0.08703591674566269\n",
"t:196, sigma_t:0.90110182762146, alpha_t:0.08119844645261765\n",
"t:195, sigma_t:0.8703591227531433, alpha_t:0.07575250416994095\n",
"t:194, sigma_t:0.8406652808189392, alpha_t:0.07067181169986725\n",
"t:193, sigma_t:0.8119844794273376, alpha_t:0.06593187898397446\n",
"t:192, sigma_t:0.7842822074890137, alpha_t:0.061509858816862106\n",
"t:191, sigma_t:0.7575250267982483, alpha_t:0.057384420186281204\n",
"t:190, sigma_t:0.731680691242218, alpha_t:0.053535666316747665\n",
"t:189, sigma_t:0.7067181468009949, alpha_t:0.049945052713155746\n",
"t:188, sigma_t:0.6826071739196777, alpha_t:0.0465952567756176\n",
"t:187, sigma_t:0.6593188047409058, alpha_t:0.04347012937068939\n",
"t:186, sigma_t:0.6368249654769897, alpha_t:0.040554605424404144\n",
"t:185, sigma_t:0.6150985956192017, alpha_t:0.03783462941646576\n",
"t:184, sigma_t:0.5941134095191956, alpha_t:0.035297077149152756\n",
"t:183, sigma_t:0.5738441944122314, alpha_t:0.032929714769124985\n",
"t:182, sigma_t:0.5542664527893066, alpha_t:0.030721131712198257\n",
"t:181, sigma_t:0.5353566408157349, alpha_t:0.028660673648118973\n",
"t:180, sigma_t:0.5170920491218567, alpha_t:0.026738420128822327\n",
"t:179, sigma_t:0.4994505047798157, alpha_t:0.024945080280303955\n",
"t:178, sigma_t:0.48241087794303894, alpha_t:0.02327202633023262\n",
"t:177, sigma_t:0.4659525752067566, alpha_t:0.021711179986596107\n",
"t:176, sigma_t:0.4500557780265808, alpha_t:0.020255019888281822\n",
"t:175, sigma_t:0.4347013235092163, alpha_t:0.018896525725722313\n",
"t:174, sigma_t:0.4198707044124603, alpha_t:0.017629140987992287\n",
"t:173, sigma_t:0.40554606914520264, alpha_t:0.016446761786937714\n",
"t:172, sigma_t:0.39171016216278076, alpha_t:0.015343685634434223\n",
"t:171, sigma_t:0.3783462643623352, alpha_t:0.014314589090645313\n",
"t:170, sigma_t:0.365438312292099, alpha_t:0.013354516588151455\n",
"t:169, sigma_t:0.3529707193374634, alpha_t:0.012458832934498787\n",
"t:168, sigma_t:0.34092849493026733, alpha_t:0.01162322424352169\n",
"t:167, sigma_t:0.32929712533950806, alpha_t:0.01084365975111723\n",
"t:166, sigma_t:0.31806257367134094, alpha_t:0.010116379708051682\n",
"t:165, sigma_t:0.307211309671402, alpha_t:0.009437879547476768\n",
"t:164, sigma_t:0.29673025012016296, alpha_t:0.008804883807897568\n",
"t:163, sigma_t:0.2866067588329315, alpha_t:0.008214343339204788\n",
"t:162, sigma_t:0.27682867646217346, alpha_t:0.007663411553949118\n",
"t:161, sigma_t:0.2673841714859009, alpha_t:0.007149429526180029\n",
"t:160, sigma_t:0.25826188921928406, alpha_t:0.006669920869171619\n",
"t:159, sigma_t:0.24945081770420074, alpha_t:0.00622257124632597\n",
"t:158, sigma_t:0.24094036221504211, alpha_t:0.0058052255772054195\n",
"t:157, sigma_t:0.2327202409505844, alpha_t:0.005415871273726225\n",
"t:156, sigma_t:0.2247805893421173, alpha_t:0.005052631255239248\n",
"t:155, sigma_t:0.21711179614067078, alpha_t:0.004713753238320351\n",
"t:154, sigma_t:0.20970463752746582, alpha_t:0.0043976036831736565\n",
"t:153, sigma_t:0.20255018770694733, alpha_t:0.00410265801474452\n",
"t:152, sigma_t:0.19563983380794525, alpha_t:0.0038274943362921476\n",
"t:151, sigma_t:0.18896523118019104, alpha_t:0.0035707857459783554\n",
"t:150, sigma_t:0.1825183480978012, alpha_t:0.003331294748932123\n",
"t:149, sigma_t:0.17629140615463257, alpha_t:0.0031078660394996405\n",
"t:148, sigma_t:0.17027691006660461, alpha_t:0.002899422775954008\n",
"t:147, sigma_t:0.16446761786937714, alpha_t:0.0027049598284065723\n",
"t:146, sigma_t:0.15885651111602783, alpha_t:0.0025235391221940517\n",
"t:145, sigma_t:0.1534368395805359, alpha_t:0.002354286378249526\n",
"t:144, sigma_t:0.1482020765542984, alpha_t:0.0021963855251669884\n",
"t:143, sigma_t:0.14314588904380798, alpha_t:0.0020490745082497597\n",
"t:142, sigma_t:0.13826221227645874, alpha_t:0.0019116438925266266\n",
"t:141, sigma_t:0.1335451602935791, alpha_t:0.0017834309255704284\n",
"t:140, sigma_t:0.12898902595043182, alpha_t:0.0016638169763609767\n",
"t:139, sigma_t:0.12458833307027817, alpha_t:0.001552225323393941\n",
"t:138, sigma_t:0.12033778429031372, alpha_t:0.0014481182442978024\n",
"t:137, sigma_t:0.1162322461605072, alpha_t:0.0013509935233741999\n",
"t:136, sigma_t:0.11226677894592285, alpha_t:0.0012603829381987453\n",
"t:135, sigma_t:0.10843659937381744, alpha_t:0.0011758495820686221\n",
"t:134, sigma_t:0.10473708808422089, alpha_t:0.0010969857685267925\n",
"t:133, sigma_t:0.10116379708051682, alpha_t:0.0010234114015474916\n",
"t:132, sigma_t:0.09771241247653961, alpha_t:0.0009547715890221298\n",
"t:131, sigma_t:0.09437878429889679, alpha_t:0.000890735536813736\n",
"t:130, sigma_t:0.09115888178348541, alpha_t:0.0008309941622428596\n",
"t:129, sigma_t:0.08804883807897568, alpha_t:0.0007752598030492663\n",
"t:128, sigma_t:0.08504489064216614, alpha_t:0.0007232633652165532\n",
"t:127, sigma_t:0.08214343339204788, alpha_t:0.0006747543811798096\n",
"t:126, sigma_t:0.07934096455574036, alpha_t:0.0006294988561421633\n",
"t:125, sigma_t:0.07663410902023315, alpha_t:0.0005872786859981716\n",
"t:124, sigma_t:0.07401960343122482, alpha_t:0.0005478902021422982\n",
"t:123, sigma_t:0.07149428874254227, alpha_t:0.0005111432983539999\n",
"t:122, sigma_t:0.06905513256788254, alpha_t:0.00047686113975942135\n",
"t:121, sigma_t:0.06669919937849045, alpha_t:0.0004448783292900771\n",
"t:120, sigma_t:0.06442363560199738, alpha_t:0.0004150404711253941\n",
"t:119, sigma_t:0.06222570687532425, alpha_t:0.00038720385055057704\n",
"t:118, sigma_t:0.06010276824235916, alpha_t:0.0003612342698033899\n",
"t:117, sigma_t:0.05805225670337677, alpha_t:0.0003370064659975469\n",
"t:116, sigma_t:0.056071698665618896, alpha_t:0.00031440352904610336\n",
"t:115, sigma_t:0.054158713668584824, alpha_t:0.0002933166397269815\n",
"t:114, sigma_t:0.052310992032289505, alpha_t:0.00027364399284124374\n",
"t:113, sigma_t:0.05052630975842476, alpha_t:0.00025529079721309245\n",
"t:112, sigma_t:0.04880251735448837, alpha_t:0.00023816856264602393\n",
"t:111, sigma_t:0.047137532383203506, alpha_t:0.0002221947070211172\n",
"t:110, sigma_t:0.04552935063838959, alpha_t:0.0002072921779472381\n",
"t:109, sigma_t:0.043976034969091415, alpha_t:0.0001933891762746498\n",
"t:108, sigma_t:0.04247571527957916, alpha_t:0.00018041864677798003\n",
"t:107, sigma_t:0.041026581078767776, alpha_t:0.00016831803077366203\n",
"t:106, sigma_t:0.03962688520550728, alpha_t:0.0001570290041854605\n",
"t:105, sigma_t:0.038274943828582764, alpha_t:0.00014649714285042137\n",
"t:104, sigma_t:0.036969125270843506, alpha_t:0.00013667163148056716\n",
"t:103, sigma_t:0.03570786118507385, alpha_t:0.00012750514724757522\n",
"t:102, sigma_t:0.034489624202251434, alpha_t:0.00011895342322532088\n",
"t:101, sigma_t:0.03331294655799866, alpha_t:0.00011097524838987738\n",
"t:100, sigma_t:0.032176416367292404, alpha_t:0.00010353217658121139\n",
"t:99, sigma_t:0.03107866272330284, alpha_t:9.658833005232736e-05\n",
"t:98, sigma_t:0.030018357560038567, alpha_t:9.011018119053915e-05\n",
"t:97, sigma_t:0.028994228690862656, alpha_t:8.406653068959713e-05\n",
"t:96, sigma_t:0.02800503931939602, alpha_t:7.842822378734127e-05\n",
"t:95, sigma_t:0.02704959735274315, alpha_t:7.316807023016736e-05\n",
"t:94, sigma_t:0.02612675167620182, alpha_t:6.826072058174759e-05\n",
"t:93, sigma_t:0.025235392153263092, alpha_t:6.368249887600541e-05\n",
"t:92, sigma_t:0.02437444217503071, alpha_t:5.9411344409454614e-05\n",
"t:91, sigma_t:0.023542864248156548, alpha_t:5.542664439417422e-05\n",
"t:90, sigma_t:0.022739658132195473, alpha_t:5.1709208491956815e-05\n",
"t:89, sigma_t:0.021963853389024734, alpha_t:4.8241086915368214e-05\n",
"t:88, sigma_t:0.021214518696069717, alpha_t:4.5005581341683865e-05\n",
"t:87, sigma_t:0.020490746945142746, alpha_t:4.198707392788492e-05\n",
"t:86, sigma_t:0.01979166828095913, alpha_t:3.9171016396721825e-05\n",
"t:85, sigma_t:0.019116440787911415, alpha_t:3.654383181128651e-05\n",
"t:84, sigma_t:0.01846424862742424, alpha_t:3.409284909139387e-05\n",
"t:83, sigma_t:0.017834309488534927, alpha_t:3.1806262995814905e-05\n",
"t:82, sigma_t:0.017225859686732292, alpha_t:2.9673023163923062e-05\n",
"t:81, sigma_t:0.016638169065117836, alpha_t:2.7682868676492944e-05\n",
"t:80, sigma_t:0.016070527955889702, alpha_t:2.582618617452681e-05\n",
"t:79, sigma_t:0.015522253699600697, alpha_t:2.4094035325106233e-05\n",
"t:78, sigma_t:0.014992684125900269, alpha_t:2.247805787192192e-05\n",
"t:77, sigma_t:0.01448118221014738, alpha_t:2.0970464902347885e-05\n",
"t:76, sigma_t:0.013987131416797638, alpha_t:1.9563985915738158e-05\n",
"t:75, sigma_t:0.013509934768080711, alpha_t:1.8251834262628108e-05\n",
"t:74, sigma_t:0.013049019500613213, alpha_t:1.7027690773829818e-05\n",
"t:73, sigma_t:0.012603829614818096, alpha_t:1.588565282872878e-05\n",
"t:72, sigma_t:0.012173827737569809, alpha_t:1.4820208889432251e-05\n",
"t:71, sigma_t:0.01175849512219429, alpha_t:1.3826221220369916e-05\n",
"t:70, sigma_t:0.01135733351111412, alpha_t:1.2898902241431642e-05\n",
"t:69, sigma_t:0.010969857685267925, alpha_t:1.2033778148179408e-05\n",
"t:68, sigma_t:0.010595601983368397, alpha_t:1.1226677997910883e-05\n",
"t:67, sigma_t:0.010234113782644272, alpha_t:1.0473708243807778e-05\n",
"t:66, sigma_t:0.00988495908677578, alpha_t:9.771241820999421e-06\n",
"t:65, sigma_t:0.00954771600663662, alpha_t:9.115888133237604e-06\n",
"t:64, sigma_t:0.009221978485584259, alpha_t:8.504488505423069e-06\n",
"t:63, sigma_t:0.008907354436814785, alpha_t:7.934096174722072e-06\n",
"t:62, sigma_t:0.008603464812040329, alpha_t:7.4019603744091e-06\n",
"t:61, sigma_t:0.008309941738843918, alpha_t:6.905513146193698e-06\n",
"t:60, sigma_t:0.008026433177292347, alpha_t:6.442362973757554e-06\n",
"t:59, sigma_t:0.007752597332000732, alpha_t:6.010276592860464e-06\n",
"t:58, sigma_t:0.007488104049116373, alpha_t:5.607170351140667e-06\n",
"t:57, sigma_t:0.007232633884996176, alpha_t:5.231099294178421e-06\n",
"t:56, sigma_t:0.006985879968851805, alpha_t:4.880251708527794e-06\n",
"t:55, sigma_t:0.006747544277459383, alpha_t:4.5529354792961385e-06\n",
"t:54, sigma_t:0.006517339497804642, alpha_t:4.247571268933825e-06\n",
"t:53, sigma_t:0.00629498902708292, alpha_t:3.962688879255438e-06\n",
"t:52, sigma_t:0.0060802241787314415, alpha_t:3.696912699524546e-06\n",
"t:51, sigma_t:0.00587278651073575, alpha_t:3.4489621612010524e-06\n",
"t:50, sigma_t:0.005672425962984562, alpha_t:3.2176417334994767e-06\n",
"t:49, sigma_t:0.005478901322931051, alpha_t:3.00183614854177e-06\n",
"t:48, sigma_t:0.00529197882860899, alpha_t:2.800504034894402e-06\n",
"t:47, sigma_t:0.005111433565616608, alpha_t:2.612675189084257e-06\n",
"t:46, sigma_t:0.004937048070132732, alpha_t:2.437444436509395e-06\n",
"t:45, sigma_t:0.004768611863255501, alpha_t:2.2739659470971674e-06\n",
"t:44, sigma_t:0.004605921916663647, alpha_t:2.121451643688488e-06\n",
"t:43, sigma_t:0.004448782652616501, alpha_t:1.979166654564324e-06\n",
"t:42, sigma_t:0.004297004546970129, alpha_t:1.8464248796590255e-06\n",
"t:41, sigma_t:0.004150404594838619, alpha_t:1.7225859210157068e-06\n",
"t:40, sigma_t:0.004008806310594082, alpha_t:1.6070528090494918e-06\n",
"t:39, sigma_t:0.003872038796544075, alpha_t:1.4992684782555443e-06\n",
"t:38, sigma_t:0.0037399372085928917, alpha_t:1.3987130387249636e-06\n",
"t:37, sigma_t:0.00361234275624156, alpha_t:1.30490195715538e-06\n",
"t:36, sigma_t:0.003489101305603981, alpha_t:1.2173827599326614e-06\n",
"t:35, sigma_t:0.0033700643107295036, alpha_t:1.1357333278283477e-06\n",
"t:34, sigma_t:0.0032550885807722807, alpha_t:1.0595601906970842e-06\n",
"t:33, sigma_t:0.003144035581499338, alpha_t:9.884960263661924e-07\n",
"t:32, sigma_t:0.0030367712024599314, alpha_t:9.221980121765228e-07\n",
"t:31, sigma_t:0.0029331662226468325, alpha_t:8.603464607404021e-07\n",
"t:30, sigma_t:0.0028330960776656866, alpha_t:8.026433420127432e-07\n",
"t:29, sigma_t:0.0027364399284124374, alpha_t:7.488103506148036e-07\n",
"t:28, sigma_t:0.0026430815923959017, alpha_t:6.985880531829025e-07\n",
"t:27, sigma_t:0.002552908146753907, alpha_t:6.517340125355986e-07\n",
"t:26, sigma_t:0.0024658110924065113, alpha_t:6.080224466131767e-07\n",
"t:25, sigma_t:0.0023816856555640697, alpha_t:5.672426937053388e-07\n",
"t:24, sigma_t:0.002300430089235306, alpha_t:5.291978482091508e-07\n",
"t:23, sigma_t:0.0022219468373805285, alpha_t:4.937048174724623e-07\n",
"t:22, sigma_t:0.0021461411379277706, alpha_t:4.605921901656984e-07\n",
"t:21, sigma_t:0.00207292172126472, alpha_t:4.297004636555357e-07\n",
"t:20, sigma_t:0.0020022003445774317, alpha_t:4.008806229194306e-07\n",
"t:19, sigma_t:0.0019338917918503284, alpha_t:3.739937426416873e-07\n",
"t:18, sigma_t:0.0018679136410355568, alpha_t:3.489101345621748e-07\n",
"t:17, sigma_t:0.001804186380468309, alpha_t:3.255088643072668e-07\n",
"t:16, sigma_t:0.0017426334088668227, alpha_t:3.0367712611223396e-07\n",
"t:15, sigma_t:0.0016831803368404508, alpha_t:2.833096175436367e-07\n",
"t:14, sigma_t:0.0016257556853815913, alpha_t:2.643081700171024e-07\n",
"t:13, sigma_t:0.0015702900709584355, alpha_t:2.465810950980085e-07\n",
"t:12, sigma_t:0.0015167169040068984, alpha_t:2.300430281820809e-07\n",
"t:11, sigma_t:0.0014649713411927223, alpha_t:2.1461410426582006e-07\n",
"t:10, sigma_t:0.0014149913331493735, alpha_t:2.002200574224844e-07\n",
"t:9, sigma_t:0.001366716343909502, alpha_t:1.867913539399524e-07\n",
"t:8, sigma_t:0.001320088398642838, alpha_t:1.742633344292699e-07\n",
"t:7, sigma_t:0.001275051268748939, alpha_t:1.6257557433618786e-07\n",
"t:6, sigma_t:0.001231550588272512, alpha_t:1.5167169920005108e-07\n",
"t:5, sigma_t:0.0011895340867340565, alpha_t:1.4149912885841331e-07\n",
"t:4, sigma_t:0.001148951007053256, alpha_t:1.3200885007336183e-07\n",
"t:3, sigma_t:0.0011097524547949433, alpha_t:1.2315506126014952e-07\n",
"t:2, sigma_t:0.0010718912817537785, alpha_t:1.1489509432749401e-07\n",
"t:1, sigma_t:0.0010353218531236053, alpha_t:1.0718913756591064e-07\n",
"t:0, sigma_t:0.0010000000474974513, alpha_t:1.0000001537946446e-07\n"
]
}
],
"source": [
"samples_pred = sbm_sample(n_samples=100000, score_model=score_model, sigmas=sigmas)\n",
"samples_pred = samples_pred.cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 281
},
"id": "dtTR7L35w4DK",
"outputId": "23907bcb-80c4-4009-c3e3-1dfdb2621b87"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQoAAAEICAYAAACnA7rCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAAsTAAALEwEAmpwYAAAhVklEQVR4nO2de7RkVX3nv9+quu/u200/aLC7BRXCgEZFGVqWJnGCOkAMGKMZnNFIlHElkaXJcmXGGWfFyYyu6MpEjQsmgNGoiYMwIIraiKAyaAzKY+EDGrQhIt08mm76dt/b91G3qn7zx/7tOqfOrbrn0lW3btXt72etWnUe+5y9696q3/nu3/7t/aOZQQghFqOw0g0QQvQ+MhRCiFxkKIQQuchQCCFykaEQQuQiQyGEyEWGYoUg+VmSH/LtXyP5UJfqNZKndKOuo4XkJSS/t9LteDaQfC7JKZLFlW7LciBDsQgkf0Fyxr8AT/mPe02n6zGz75rZaUtoz7L+gEi+kOQ3ST5DcoLkPSQvWK76lgOSt5OcJTlJ8rB/hveTHFrOes3sl2a2xsyqqXZcupx1dhMZinx+28zWAHgZgLMA/LdsAZKlrrdqefgqgFsBnADgeADvAXB4RVt0dFxmZmsBnAjgfQAuBrCTJFe2Wf2LDMUSMbO9AG4G8CKgLuHfTfLnAH7ux15P8j5/Gn+f5Ivj9STPJHmvP+muBTCcOvdqkntS+9tJfonk0yQPkLyc5OkArgRwjiucCS87RPJ/kfylq54rSY6k7vVnJJ8g+TjJd7T6fCQ3AXgegE+ZWdlf/2Rm3/Pzx5H8mrfpoG9vS11/O8kP+eeeIvlVkhtJfsGf7HeRPDlV3ki+h+QjJPeT/CuSTb+PJP8VyVtd6TxE8veW+D87Yma3A7gQwDkAfsvvV3CV8bD/fa8jucHPnexte7v/TfeT/ECqLWeTvNs/01MkP5a5rkTywwB+DcDl/re4nOQVJP8687luIvmnS/ksK46Z6dXiBeAXAF7j29sB3A/gf/q+ITx9NwAYAXAmgH0AdgAoAni7Xz8EYBDAowD+FMAAgDcBmAfwIb/XqwHs8e0igB8B+DiAMQSD8io/dwmA72Xa+HEAN3k71iKogr/0c+cBeArBuI0B+D/e7lOafFYiGLyvAXgDgC2Z8xsB/C6AUa/n/wL4cur87QB2A3gBgHUAHgDwMwCvAVAC8HkAf58qbwC+4+1+rpe9NPs5vd2PAfgDv8+ZAPYDOKPF/+z2eJ/M8TsAfNS33wvgTgDb/P9zFYBr/NzJ3rZP+f/1JQDmAJzu5/8ZwNt8ew2AV2SuKzVrB4CzATwOoOD7mwBMZ//Ovfpa8Qb08gvhhz4FYALhh/6/AYz4OQPwm6myfws3IqljDwH4DQC/7l8Sps59H80NxTkAno5fuMz96j8g3yeAIwBekDp2DoB/8e3PAPhI6tyvoIWh8PPbAFwO4GEANf9xndqi7EsBHEzt3w7gA6n9vwZwc2r/twHcl9o3AOel9v8YwLeynxPAvwPw3UzdVwH4YIt2NfxAU8e/iKCWAGAXgHNT505EMNyl1A9+W+r8DwFc7Nt3APgLAJsy94/XNTUUqXpf69uXAdi50t/xpb7U9cjnDWa23sxOMrM/NrOZ1LnHUtsnAXifdzsmvGuwHcBz/LXX/BviPNqivu0AHjWzyhLathnhCX9Pqs5v+HF4vek2tqoTAGBme8zsMjN7gX+eIwhKACRHSV5F8lGShxF+MOszXv6nUtszTfazjuBs257TpFknAdiR+bv+BwQ/yrNhK4BnUve8MXW/XQCqALakyj+Z2p5Otf2dCAb3Qe9Ovf5ZtOFzAN7q228F8A/P6hOsIDIU7ZH+4T8G4MNuVOJr1MyuAfAEgK0ZZ9pzW9zzMQDPbeEgzU713Y/wA3xhqs51Fpyv8Hq3L6HOhRWZPQbgCrhPBsEpeBqAHWY2jqCSgKBqjpZs2x5vUuYxAP8v83ddY2Z/tN
RKSG4H8HIA303d8/zMPYct+KEWxcx+bmZvQXD2fhTA9STHmhVtcuwfAVxE8iUATgfw5aV+hpVGhqJzfArAH5LcwcAYyd8iuRahX1sB8B6SAyTfiNBnbcYPEX7gH/F7DJN8pZ97CsA2koMAYGY1r/fjJI8HAJJbSf5bL38dgEtInkFyFMAHWzXenZV/QfIUd/ZtAvAOhL48EPwSMwAm3PHX8l7Pgj/zercj+A2ubVLmawB+heTb/G83QPJfu3N3UVwF/QaAryD8XXf6qSsBfJjkSV5uM8mLltJgkm8ludn/9hN+uNak6FMAnp8+YGZ7ANyFoCRuyKjTnkaGokOY2d0A/iNCH/8ggmPvEj9XBvBG338God/9pRb3qSL0508B8EsAe7w8AHwbwaH6JMn9fuw/e113epfgNoQnP8zsZgCf8Ot2+3srygj97NsQhkR/iuDEu8TPfwLBubcfwXh8Y5F7LZWvALgHwH0Avg7g09kCZjYJ4HUIQ5yPI3QJPorghGzF5SQnEX6snwBwA4I/JP6g/wbBAfxNL3cnghN6KZwH4H6SU36fi1v84P8GwJt8hOiTqeOfA/Cr6KNuB+DONSG6DUlDcJTuXum2dBOSv47QBTnJ+ujHJ0UhRJcgOYDQxfq7fjISQAcMBUNw0HdIPkDyfpLvbVKGJD9JcjfJH5N8Wbv1CtFPuE9lAmEo9hMr2pijoBOhxxUA7zOze91xdw/JW83sgVSZ8wGc6q8dCDEHS+0TilWImR1T4dRmtgsheKwvaVtRmNkTZnavb08ijElvzRS7CMDnLXAnwvj7ie3WLYToDh2dzMQQy38mgB9kTm1FY3DNHj/2RJN7vAvAuwCgiOLLRzHeySYKIVJM4uB+M9ucV65jhoJh+vUNAP7EzI56xqGZXQ3gagAY5wbbwXM71EIhRJbb7PpFo3UjHRn1cG/uDQC+YGbN4gP2ojEKb5sfE0L0AZ0Y9SBCoMwuM/tYi2I3Afh9H/14BYBDZrag2yGE6E060fV4JYC3AfgJyfv82H+FzyswsysRQmcvQIgOnEaYMiyE6BPaNhQWFjZZdKjLg0ve3W5dQoiVQZGZQohcZCiEELnIUAghcpGhEELkIkMhhMhFhkIIkYsMhRAiFxkKIUQuMhRCiFxkKIQQuchQCCFykaEQQuQiQyGEyEWGQgiRiwyFECIXGQohRC4yFEKIXGQohBC5dGoV7s+Q3Efypy3Ov5rkIZL3+evPO1GvEKI7dCqvx2cBXA7g84uU+a6Zvb5D9QkhukhHFIWZ3QHgmU7cS4hlh4XwOorzLBbBYnGZGta7dNNHcQ7JH5G8meQLu1ivEKJNOpp7dBHuBXCSmU2RvADAlxEymy8gnXt0GKNdap4QYjG6oijM7LCZTfn2TgADJDe1KHu1mZ1lZmcNYKgbzRPHGlYLr2wXw/dZIFhgcp6FepfDagar2THXBemKoSB5gqceBMmzvd4D3ahbCNE+Hel6kLwGwKsBbCK5B8AHAQwA9ZSCbwLwRyQrAGYAXOzZw4ToPq4iWAgJ7lgcbDheL2a1hZcOBhVh8xW/duHtrWZxoxOt7Qk6YijM7C055y9HGD4VQvQh3XJmCtEdssOa/lRnaaBJUTZeExXGyHDYr6YUQdHLlOdDmcGgQmx+3utJBDJdUVglnLNqtbGePlQaCuEWQuQiRSH6m+xTOiqIOCLBgcZ9ABx0dVFwX8Wwj65FVeBqIa0SUAk+CQy72qCrkUOHvdqUSoiP3+gL8ar7WVlIUQghcpGiEP1N5qmcKInMyMZwEpPDkn/tx0bC+0BQGOZKg64OqmtH6tcUD01n6g1qg/Ge5XJyKvoxvC11JeFl6vt9hBSFECIXKQrRf6RHNjKjGhwIX+m6snD1UPdLAMC6cQBAdV3zKQLljcEPUZyp1I/VRsI1rAYlUZyaCyeG/P6HjiTNG3L14iMiNjnV2KbY9D5SFlIUQohcZCiEELmo6yH6j5QDc8HELJfz5sOXhRg8NZZ0M2
rjwUlpHo49uzl0FcprwnNzbl24tjST/DwGZkKXo1YM54YnwhDq4MHgoCwMJmULBw77RbE75M5SD+CyWgz/Ttre690QKQohRC5SFKL3yQQoNQRPxW1/j6HVHPKgKXcs1saG69fMnhDUxdy6cM305qASzKuZO65+96Qec5XhPsu5deGns86dm6VSUrbEdQCAwr6JcK07VM1mGtusgCshxGpCikL0NbFvH5/a9WAqf6+tGwMAlDcnPorpzeGJPrMxqIByEAAobwm+g8ENswCAaiV5jlo1lJ1/KigTcwUxtS34H9bsTdpUmvSh0zVepwdacdT3Z8L9a3PzyUU9HtYtRSGEyEWKQvQ+izxlk/6+T+CKvgmf8FUdDz6KymjyTDxyQlADMyeG+xY2BwWweV0IjDphzSQAYMNQErb9s4nNAIB90Y9RDcqiMBfuNTiV+E1K0+HcQMV9KnG6emyjT1kvxMlnSJSR9ejghxSFECIXKQrRN9QnWdVSi8R4LETdNxHx+InKSDhfHk+e+FWf67X25EMAgPGR4DN41fGPAADOH/8xAOD+ua31a1645nEAwPW1MwEA+10kVA+Hm82uT565w/vDtsVp7FHlRH+KH6/VmqwG2aO+CikKIUQu3co9SpKfJLmb5I9JvqwT9QohukO3co+ej5Dw51QAOwD8rb8LkU9mHczCQPK1ra+GHR2Dvu5EDM+ujHqY9noiy9x8uM/WTSHk+nXjPwllLRy/YOzBetnrDr8UALB5NDg8D0ysAQCU14fuw8i+5P6VsVDnYExIUcrMGq01rsYVNnt75e5u5R69CMDnLXAngPUkT+xE3UKI5adbzsytAB5L7e/xY09kCyqloFiAxclUjWtQAsn6E4hP6RjK7aHVhYpP5kotR1FeH8puGQsh1aOlMDz6i3IYAr1kfB8A4M65ZIWrF42Er+/102emm4TacNgozSTP3OKsn4zralZ8zDOuiuXOzBheDgDmQVi9Ss85M5VSUIjeo1uKYi+A7an9bX5MiFyya082TMn2vn3dRxHXuzwuqNHSTFz5OvmqsxKe9JOz4UG0fmCmob6HK8EP8WRlS/1YkaGe044LamPf02HFq6EDcY3OZKizUMkEWEVlUcsojVTekLi257EecHUTgN/30Y9XADhkZgu6HUKI3qRbuUd3ArgAwG4A0wD+oBP1imOD+sSvJgFXhWH/Cjd5SgNAdThcMzCZOuiPxyNPhpGL75ROBQBsOjkoiWsmzgIADBeSSVuz7uT4p0eeHw5Mh3oHJ8Lu4OGk3kLZQ8MzK3fHlcBrh0M9MZNY+jP2Kt3KPWoA3t2JuoQQ3Uch3KJviEqiYfm7+CSOmbymw1O8MBme3gMeT1HakFwz9sugPuY2hmMTI2Eq+h1rTgEATJXDtYdnksVuypVQtjbvMRL7w3vRZ5QX51Mqp+xt8tENOxLaZNONCsOahXD3KD036iGE6D2kKETfke7PszjYeNJ9FPSneGkgPPmHDyblClVXB3H5ulpQEA9WnhOuHXB/w+Ek+GLwUHimDrlwGf+XoAbGngwHhp5O1ALjgjTTPppSn17uPhaPJu3VKMxmSFEIIXKRoRBC5KKuh+h92Pp5lh5iBJCkEIzDpPOhmzI4kSQRLs75qtgMXYuZTaELMvRM2I/h3oUkoyBKHmE9OBkngYV6hx8P466WyuvBWU8lGJ2YMYFxdMbG4KqWn6r3kKIQQuQiRSF6n4zTb0F2sFSZ7OSqOPl7oJI4QItr4hyiEOY9fCCUiutqWrExzwcADEyH+w9MBplRT2AcJ59NJEmK7ZmJxra5kqiVy+hXpCiEELlIUYi+o3mgkod5e4BVDMCyGc/OVUuHWIcn+1A5lKmNBIUxOOCKohTei9OpEGtXGQXP2cGZucbq02rB667FuksDjWX7aFg0IkUhhMhFikL0H02eyHFRm2xWrrrPopyoA46GBWnox4pxIZmY+TwuXZe6BoOuCg4c9JsEhVGbDBO80ETl1JVPHJnpQy
URkaIQQuQiRSFWB/WntYdnu5KoxywwWfzWXAUU1oZp5nXlMO/vMQajmHqOMi4sU228pglxdCNZjKa3p5AvBSkKIUQuUhRiVVGfil4IqqDmU76Zzi7mC/LW/QuZRW/iAjM2lRrZKMQJZNa432R5vtWkJCJSFEKIXGQohBC5dCql4HkkH/KUge9vcv4Skk+TvM9fl3aiXiHAQsOkMRYIFgirGaxmYLEYXn6cBcLK5TBRq1oFqlWwUAALBVhlPrxm52Czc8FB6q94Tb1M3K9Ww8vrS79WE237KEgWAVwB4LUIiX3uInmTmT2QKXqtmV3Wbn1CiO7TCWfm2QB2m9kjAEDyiwgpBLOGQojOkwliyubwjHkyFnMs1qaiU9OViQdINVUFLYKm0hPVVpuaADrT9WiVLjDL73om8+tJbm9yHkBIKUjybpJ3z2OuVTEhRBfpljPzqwBONrMXA7gVwOdaFVRKQdEW0a+wGO7XiP6FhbdwH0PKR7HgvhnfSN1XUa0urQ19RicMRW66QDM7YGZRHvwdgJd3oF4hRJfohKG4C8CpJJ9HchDAxQgpBOuQPDG1eyGAXR2oV4ijo5U6aKUelnKPVU7bzkwzq5C8DMAtCIH2nzGz+0n+DwB3m9lNAN5D8kIAFQDPALik3XqFEN2DZr3roR3nBtvBc1e6GWK1En0Mx5AyyHKbXX+PmZ2VV06RmUKIXGQohBC5aPaoOHY5hrsczxYpCiFELjIUQohcZCiEELnIUAghcpGhEELkIkMhhMhFhkIIkYsMhRAiFxkKIUQuMhRCiFxkKIQQuchQCCFykaEQQuQiQyGEyEWGQgiRS7dSCg6RvNbP/4DkyZ2oVwjRHdo2FKmUgucDOAPAW0iekSn2TgAHzewUAB8H8NF26xVCdI9OKIp6SkEzKwOIKQXTXIQk6c/1AM4lyQ7ULYToAt1KKVgvY2YVAIcAbGx2M6UUFKL36DlnplIKCtF7dCWlYLoMyRKAdQAOdKBuIUQX6EpKQd9/u2+/CcC3rZczDwkhGuhWSsFPA/gHkrsRUgpe3G69Qoju0ZG8Hma2E8DOzLE/T23PAnhzJ+oSQnSfnnNmCiF6DxkKIUQuMhRCiFxkKIQQuchQCCFykaEQQuQiQyGEyEWGQgiRiwyFECIXGQohRC4yFEKIXGQohBC5yFAIIXKRoRBC5CJDIYTIRYZCCJGLDIUQIhcZCiFELm0ZCpIbSN5K8uf+flyLclWS9/kru/CuiLDQ+iXECtLuN/D9AL5lZqcC+JbvN2PGzF7qrwvbrFMI0WXaNRTpVIGfA/CGNu/Xfyzlie9lWCw2fS12DxYIFihlIVaUdr95W8zsCd9+EsCWFuWGPU3gnSTfsNgNlVJQiN4jd7l+krcBOKHJqQ+kd8zMSLZK6nOSme0l+XwA3yb5EzN7uFlBM7sawNUAMM4NPZ8kiIWQa5ml4XDAauGtWk3KlAYaztUpFsN7uRxvtuC+tXguOdH8XkIsI7mGwsxe0+ocyadInmhmT5A8EcC+FvfY6++PkLwdwJkAmhoKIUTv0W7XI50q8O0AvpItQPI4kkO+vQnAKwE80Ga9Qogu0m6msI8AuI7kOwE8CuD3AIDkWQD+0MwuBXA6gKtI1hAM00fMrH8NRSuHYqYrwMHBZDt2MepFa34rv1csW7N0oRbX2sJ2xLqz3RJ1U0SHaMtQmNkBAOc2OX43gEt9+/sAfrWdeoQQK0tHco+uWvyJHB2LDacyT/romGQp/EnTiiI6Njnkx6Znwv6asbA/OeX3SFfgdUanqCd/pyuK2txsk+bS64sVS0mIzqCBeSFELlIUS8BSvoO6kojvUS1EBVF027t2TXJNfLJ7GR63zq91P8TQUNifn08qdT8G5uYaypoPlxaaKJZ0O4XoJFIUQohcpCgWo8nIQz3AKh4bCYFWdX/DQAiuqo2P1K8pzASlUN6yFgBQPOLKwc0054IiKEynIlHnK16fF6o1Kgqk21RvbiU2sqH9QrSLFI
UQIhcpijStYiTSx+OTPI5uRP9CIRy3YVcUI4kPYfY5wV9RnA7KYW6Th3sXgxYoTQaFURhLxV64T6JwJKgMTk6H/dHRUE+lUi9rMz6KkhmdsSqE6AhSFEKIXKQompCNkeDAwj8T3TeB8aAWasNBDcxvCr6J+TXJNbPrgz22qDr8yV+a9diIdeF4aTrxKQwcCXJgaMrjJTyOAlE1MKUeYnvjqId8E6LDSFEIIXKRoRBC5KKuR5r6WhJhd8FQKAAM+toS7sQ0H76sHBe6ItPHh/Pzo4kNroz6u4+YWozVGg73L/iI59jjyTW1Ae+eTIWLCu48jSVYm0mazcZ2WkVdD9FZpCiEELlIUQALhkULWedlynFYHw4thad3bSzsV4Z9v+QTs1IiZPJkn8i1xqVKze8XZ4xXXFlUknbMr3WVUAuKYnRvLBvkCeeSla/oQV5mjStl0duQXm1LiKNBikIIkYsUBbBgONGyE70KKXsaJ2ANhT+dDcZJYo1KYjq1ymhcSdSKYWPj1kPhVq4sJg8H1TBTGa5fMzgRzg0eCu/FTUG5jPpwqaUnkMWJaFH5aHhUdBgpCiFELlIUKeqjG9FnEfv2zYKb/FhhLoRSz60LqqDsvoXaQHJJdWN4+g+MhPeNo0cAAPPVcK+oKGqjiS+hsD/8a6LPozDfOIWc4+P1bTt82Dcay8g3ITpFuykF30zyfpI1XyezVbnzSD5EcjfJVtnEhBA9SruK4qcA3gjgqlYFSBYBXAHgtQD2ALiL5E29uMBu3TdRcvsZ1UNqAlb9WIyfGAvSoe5ecPGRVhQsBp/BC7bsBwBMlYO/oWYeKzEU6q08lfgoIvE+xdnoGwkHOFdZULY+IlNXQh46XplfUFaIZ0O7i+vuAgBy4ZqSKc4GsNvMHvGyX0RIRdhzhkII0Zxu+Ci2Angstb8HwI5WhUm+C8C7AGAYo8vbsnqljT2wBZPCRpNFaOKkrOqa8PQuVIJfoOghDNGVUBtK/AU2E/7Mjx7YEO7nwyDVWqi3fCSohFI5MbjlsMYNxvZ6DMZA4wI29SXywo1CPb6oTX1pvIyy0GiIOFraSiloZgsS/rRLv6UUFOJYoK2UgktkL4Dtqf1tfkwI0Sd0o+txF4BTST4PwUBcDODfd6Heo6bu1IyJNqopye6h06V9YUiyujGsRzGyLzgM50dDkNbIE0k3YmZLuM+shS4MR4MjsjjoXYRyOF8dTeoZfzh0FwangqgqzPuKVwfD0CoqqaHPbA6Q+nF1OURnaHd49HdI7gFwDoCvk7zFjz+H5E4AMLMKgMsA3AJgF4DrzOz+9pothOgm7Y563AjgxibHHwdwQWp/J4Cd7dTVDRZkBPMncW0mmdJdiKttx6ndPqW7NtioAKpDyb2Gn3YH6GT4c8+vDQqi4gpi4FDYH3s8qXrwcGNmsKKv5I24CncpcbjWV+ZOPohvKOBKdAaFcAshcjl2Q7hTQ6L1nJ3+9GYh9umL2avqq19zJkzOKvhw5cDa4JsYmAznC/PJitpTWwt+LOyXjvhEL/dNMCYSO5wM8gwdCgeH94dh0LgaN4Y9qOrQZNKoasYHEfORZHORCnGUSFEIIXI5dhVFaiRg4RPXl5SLox/pgKyyy4IBf/fVuEsHg8KojYTgqaGJJMQ6KofyWl9QxqsulH1Eoxre06twFz17WGnCc3ZERXEkjHo0+CWicvBoL9Nq3KLDSFEIIXI5dhXFUohKIjUaYu6TYJzSfcQzeLn6KMyGP2lxOpkVNuAqo+AzxwrlcI/SkaA6zBe9YTXxUcTp65z3WIuDE2F/KPg+GjKFzWcmiElJiA4jRSGEyEWKognZ0Y+0P4Alzy06ORX2fbk8xtGIWc8VWkvUQXEmXD9y2Jex8+hNzgU/R4zFsFJqJObpg2HDp7PXJ3z5/ZtNHdckMLFcSFEIIXKRoRBC5KKuxyIkK16llquKwUxFdyrG8O5nGsO/OZi6xrsl9MTDnMysyRkdpKmJXtFZGbsayIaXp4
ZsF3RD1OUQHUaKQgiRixRFM7JP5GbBWeXGrFwN62oiswJ2DPuOuUFj0JZPLIuOyvTKWtkVtLOrVzWQWaFLikJ0GikKIUQuUhRLoOEpnpnCbTVXC+4niKqgviI2AHhAVFQS9QCp+OSPSiO1DmbNy8Q8qPUh22JjeHnDfYRYJqQohBC5SFE8W/zpHX0V9Sd8DLDyFbbNp6GnqS+Mk/Ep1P0PqSCt7NT3elll/xIrgBSFECKXthQFyTcD+O8ATgdwtpnd3aLcLwBMInTsK2bWMv1gv1GPtYgjFk38Bcm07xgLka8KkmuU5UusPMueUjDFvzGz/W3WJ4RYAbqRUlAI0ed0y5lpAL7JkEvvKs8GtqpYMHOz4WSjA3QBmu0pepxupRR8lZntJXk8gFtJPmhmd7Sor/u5R4UQi9KNlIIws73+vo/kjQgZzpsair7PPXo0qkBKQvQ4yz48SnKM5Nq4DeB1CE5QIUSfsOwpBQFsAfA9kj8C8EMAXzezb7RTrxCiuyx7SkEzewTAS9qpRwixsigyUwiRiwyFECIXGQohRC4yFEKIXGQohBC5yFAIIXKRoRBC5CJDIYTIRYZCCJGLDIUQIhcZCiFELjIUQohcZCiEELnIUAghcpGhEELkIkMhhMhFhkIIkYsMhRAiFxkKIUQu7S6u+1ckHyT5Y5I3klzfotx5JB8iuZvk+9upUwjRfdpVFLcCeJGZvRjAzwD8l2wBkkUAVwA4H8AZAN5C8ow26xVCdJG2DIWZfdPMKr57J4BtTYqdDWC3mT1iZmUAXwRwUTv1CiG6Sydzj74DwLVNjm8F8Fhqfw+AHa1ukk4pCGDuNrt+NSYL2gRgNWZ2X62fC1i9n+20pRTqSO5Rkh8AUAHwhWfTwmakUwqSvNvMzmr3nr2GPlf/sVo/G8m7l1Ku7dyjJC8B8HoA55pZs1yhewFsT+1v82NCiD6h3VGP8wD8JwAXmtl0i2J3ATiV5PNIDgK4GMBN7dQrhOgu7Y56XA5gLYBbSd5H8kqgMfeoOzsvA3ALgF0ArjOz+5d4/6vbbF+vos/Vf6zWz7akz8XmvQUhhEhQZKYQIhcZCiFELj1tKJYaIt6PkHwzyftJ1kj2/bDbag3TJ/kZkvtIrqp4HpLbSX6H5AP+PXzvYuV72lBgCSHifcxPAbwRwB0r3ZB2WeVh+p8FcN5KN2IZqAB4n5mdAeAVAN692P+spw3FEkPE+xIz22VmD610OzrEqg3TN7M7ADyz0u3oNGb2hJnd69uTCCOSW1uV72lDkeEdAG5e6UaIpjQL02/5pRO9BcmTAZwJ4AetynRyrsdR0e0Q8W6ylM8mxEpCcg2AGwD8iZkdblVuxQ1FB0LEe5a8z7aKUJh+H0JyAMFIfMHMvrRY2Z7ueiwxRFysPArT7zNIEsCnAewys4/lle9pQ4EWIeKrAZK/Q3IPgHMAfJ3kLSvdpqOlzTD9nobkNQD+GcBpJPeQfOdKt6lDvBLA2wD8pv+27iN5QavCCuEWQuTS64pCCNEDyFAIIXKRoRBC5CJDIYTIRYZCCJGLDIUQIhcZCiFELv8f8a1v3jEvCWgAAAAASUVORK5CYII=\n"
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"%matplotlib inline\n",
"import numpy as np\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.hist2d(samples_pred[:,0], samples_pred[:,1], range=((-2, 2), (-2, 2)), cmap='viridis', rasterized=False, bins=100, density=True)\n",
"plt.gca().set_aspect('equal', adjustable='box')\n",
"plt.xlim([-2, 2])\n",
"plt.ylim([-2, 2])\n",
"plt.title('Predicted Sample Density')\n",
"plt.show()"
]
}
],
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment