Skip to content

Instantly share code, notes, and snippets.

@jvanvugt
Created November 16, 2018 22:05
Show Gist options
  • Save jvanvugt/107d0703768a88aad586a843447129c6 to your computer and use it in GitHub Desktop.
Save jvanvugt/107d0703768a88aad586a843447129c6 to your computer and use it in GitHub Desktop.
Variational AutoEncoder
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import math\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"from torchvision import datasets, transforms\n",
"import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"PROBA_OUT = False  # decoder outputs a Gaussian over each pixel value\n",
"SSE_RECON_LOSS = False  # only used when not PROBA_OUT: True -> SSE loss, False -> BCE\n",
"# True exactly in the BCE case; presumably the training loop then squashes the\n",
"# decoder output with a sigmoid so reconstructions lie in (0, 1).\n",
"DO_SIGMOID = not (PROBA_OUT or SSE_RECON_LOSS)\n",
"\n",
"# Seed torch's global RNGs so weight init, data shuffling and latent sampling\n",
"# should repeat across Restart & Run All (some CUDA kernels may still be\n",
"# nondeterministic).\n",
"SEED = 0\n",
"torch.manual_seed(SEED);"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class VAE(nn.Module):\n",
"    \"\"\"Fully-connected variational autoencoder for flattened 28x28 images.\n",
"\n",
"    The encoder maps each image to the mean and log standard deviation of a\n",
"    diagonal Gaussian over `z_dim` latent variables. The decoder maps a latent\n",
"    sample back to 28*28 outputs, or 2*28*28 outputs when `probabilistic_output`\n",
"    is truthy (presumably two per-pixel distribution parameters that the caller\n",
"    splits -- see PROBA_OUT).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, z_dim=2, hidden_dim=512, probabilistic_output=False):\n",
"        super().__init__()\n",
"        # Last layer emits z_dim means followed by z_dim log standard deviations.\n",
"        self.encoder = nn.Sequential(\n",
"            nn.Linear(28*28, hidden_dim),\n",
"            nn.ReLU(),\n",
"            nn.Linear(hidden_dim, z_dim*2),\n",
"        )\n",
"        # One output per pixel, or two per pixel for probabilistic output.\n",
"        self.decoder = nn.Sequential(\n",
"            nn.Linear(z_dim, hidden_dim),\n",
"            nn.ReLU(),\n",
"            nn.Linear(hidden_dim, (bool(probabilistic_output)+1)*28*28),\n",
"        )\n",
"        self.z_dim = z_dim\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Encode `x`, draw one reparameterised latent sample, and decode it.\n",
"\n",
"        Args:\n",
"            x: batch of flattened images; the last dim must be 28*28 for the\n",
"               first encoder layer (presumably shape (batch, 784)).\n",
"\n",
"        Returns:\n",
"            (decoder_output, (z_mu, z_sigma)), where z_mu and z_sigma have shape\n",
"            (batch, z_dim) and z_sigma is the (positive) standard deviation.\n",
"        \"\"\"\n",
"        z_mu_sigma = self.encoder(x).view(x.shape[0], 2, self.z_dim)\n",
"        z_mu, z_log_sigma = z_mu_sigma[:, 0], z_mu_sigma[:, 1]\n",
"        z_sigma = torch.exp(z_log_sigma)\n",
"        # Reparameterisation trick: z = mu + sigma * eps, eps ~ N(0, I).\n",
"        # randn_like draws eps on z_mu's device and dtype instead of a\n",
"        # hard-coded .cuda(), so the model also runs on CPU-only machines.\n",
"        eps = torch.randn_like(z_mu)\n",
"        z_sample = eps * z_sigma + z_mu\n",
"        return self.decoder(z_sample), (z_mu, z_sigma)\n",
"\n",
"\n",
"def gaussian_log_prob(x, mu, sigma):\n",
"    \"\"\"Elementwise log-density of `x` under a Normal with mean `mu`.\n",
"\n",
"    NOTE(review): despite its name, `sigma` is used as the *variance*: the\n",
"    result is -0.5*log(2*pi*sigma) - (x - mu)**2 / (2*sigma).\n",
"    \"\"\"\n",
"    log_norm = -0.5 * torch.log(2 * math.pi * sigma)\n",
"    squared_error = (x - mu) ** 2\n",
"    return log_norm - squared_error / (2 * sigma)\n",
"\n",
"# Run on the GPU when one is available, otherwise fall back to the CPU\n",
"# instead of failing on a hard-coded .cuda(). NOTE(review): batches fed to\n",
"# `vae` must be moved to DEVICE as well.\n",
"DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"vae = VAE(probabilistic_output=PROBA_OUT).to(DEVICE)\n",
"optim = torch.optim.Adam(vae.parameters())\n",
"\n",
"# MNIST training split, downloaded into ./data on first run and reshuffled\n",
"# every epoch.\n",
"train_loader = torch.utils.data.DataLoader(\n",
"    datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor()),\n",
"    batch_size=128, shuffle=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c65c9d62b2e54fb29247e64071545e78",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=50), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 0; last KL=334.67828369140625; last RL=15596.2646484375\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAC7CAYAAAB1qmWGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAEk9JREFUeJzt3V+M1fWZx/HPI3/kr8CAkEGoNkI2VuNKJLixZsU0JK4X1V40qVdusgm9WBNNvCjxBthkE7Np7V5004SNRpp0bWpsF9KL7RKDaM2GiMYUFLsYBPkzDH+F4d8MA89ecNiM/J4vc35zfufM+X3n/UrIzDzzPed8fzMPDz/O95+5uwAA9XfbeHcAAFANCjoAZIKCDgCZoKADQCYo6ACQCQo6AGSCgg4AmaCgA0AmWiroZvakmf3FzL4ws3VVdQoYb+Q26sjGulLUzCZJ+l9JayQdlvShpGfd/bPqugd0HrmNuprcwmNXSfrC3fdLkpn9RtLTkpJJb2bsM4C2cner4GnIbXSdZnK7lbdc7pJ0aMTXhxsxoO7IbdRSK3fo0b8WhbsUM1sraW0LrwN0GrmNWmqloB+WtHTE10skHb25kbtvkrRJ4r+lqA1yG7XUylsuH0pabmbfNrOpkn4kaWs13QLGFbmNWhrzHbq7D5vZ85L+KGmSpNfd/dPKegaME3IbdTXmaYtjejH+W4o2q2iWS2nkNtqt3bNcAABdhIIOAJmgoANAJlqZtggAY2LW/FBHmbZlpMYPOzmuWDXu0AEgExR0AMgEBR0AMkFBB4BMMCgK4JaiQcnUQGUUv+224n1jFEtJtZ00aVIhdu3atbBtFL969WrTbVPP220DqNyhA0AmKOgAkAkKOgBkgoIOAJmgoANAJpjlAmSs1dkokjR5crFMzJgxI2wbxaPZKFOmTAkfX2ZGzJUrVwqxwcHBsO3w8HBTj5ekoaGhpmJSPMslNXsmFa8Sd+gAkAkKOgBkgoIOAJmgoANAJloaFDWzA5IGJF2VNOzuK6vo1ES0YcOGQmz9+vUtP+/GjRubblvm9aLnja6hruqY29FAZzQgKcUDndOmTQvbLly4sBC7++67w7Zz584txObPn1+ILViwoOl+XbhwIWwbDVSePn06bHv27NlCbGBgIGx75syZQuzcuXNh24sXLxZi58+fD9teunSpEKt6oLSKWS5PuPvJCp4H6DbkNmqFt1wAIBOtFnSX9N9m9pGZra2iQ0CXILdRO62+5fJddz9qZgslbTOzz939vZENGn8Z+AuBuiG3UTst3aG7+9HGx+OSfi9pVdBmk7uvrMOgEnADuY06srFu0G5mMyXd5u4Djc+3Sfond/+vWzymu3aDb7PVq1eH8Wg2SaptnaRm1HRy9ou7t3xEfLfndmo5f7REPrXEPlqi39vbG7a99957m4pJ0vLlywuxefPmFWKpGTXRLJfp06eHbaMZJocOHQrbHj16tBD7+uuvw7YnTpwoxPr7+8O2p06daiomxbNnLl++HLaNNJPbrbzlskjS7xvJNVnSf9wq4YEaIbdRS2Mu6O6+X9JfV9gXoCuQ26grpi0CQCYo6ACQCfZDr8j27dsLsRwGOt99990wvmPHjkIsp6X/dRQNiqaW/kcDjdGyfSke1Ewt3Y8GW2fOnFmIpZbHz5o1qxCbOnVq022r2P89Wo4fDWimpJ43+l2k+jvWySrcoQNAJijoAJAJCjoAZIKCDgCZoKADQCaY5VJSaiZHO2a0lDmcIjUbpdW2qI9r164VYqkDFKJZFKkZF9H2AdFrSfFMjmgp/fHjx8PHRxYvXtx0vMyMmNRMkuggitT1Dg8PN/280c93rLNZUrhDB4BMUNABIBMUdADIBAUdADLBoOgtRAOg0V7mZUWDnSybR7NSA2llBkWvXLnSVCwltUd51Ido3/GTJ+Ozt6PHp14r2mYgJdo7PbVveV9fXyGW2g892sIgei0p/b
uoEnfoAJAJCjoAZIKCDgCZoKADQCZGLehm9rqZHTezPSNiPWa2zcz2NT4WN0wGuhy5jdzYaEtPzexvJZ2X9Ct3f6AR+xdJp939FTNbJ2meu/9k1Bfr4MnoVWh1We4TTzwRxll63z7NnIx+w0TI7dQBF3PmzCnEli1bFrZ9+OGHC7EHH3wwbBstsf/8888Lsa+++ip8/OTJxYl39913X9h24cKFhVi0TYEkffbZZ4XY7t27w7b79+8vxKKZOpJ09uzZQmxoaChsG20TUGbmSzO5Peoduru/J+n0TeGnJW1ufL5Z0jNN9wroEuQ2cjPW99AXuXufJDU+Fv+pBOqJ3EZttX1hkZmtlbS23a8DdBq5jW4z1jv0fjPrlaTGx+RemO6+yd1XuvvKMb4W0EnkNmprrHfoWyU9J+mVxsctlfVoHFSxl3k0AMrgZy1lldupgf0oPm3atLDt0qVLC7FoQFKKtw+IlvOn9l6fOXNmIZba4/y224r3o6l91o8cOVKIHTx4MGwbPcfAwEDYNhoATe2dXvXe55Fmpi2+Kel/JP2VmR02s3/Q9WRfY2b7JK1pfA3UCrmN3Ix6h+7uzya+9b2K+wJ0FLmN3LBSFAAyQUEHgExQ0AEgExxwoWpmo0QHX6Rmz0Svx4wYdFI0Q2T27Nlh22hGS09PT9g2Wt6emhETWbx4cSG2ZMmSph8fLduXpH379hVi0UEWknTmzJlCLLouqTMzV8rgDh0AMkFBB4BMUNABIBMUdADIxKj7oVf6Yl26Z3TKhg0bCrFo8LNdUgOlO3bsKMSivk5EZfZDr1K35nZqP/RoUPPRRx8N237/+98vxO6///6wbbSfeWov8cj8+fObek4p3uN869atYdsPPvigEDt27FjYdnBwsBDrhsHPSvZDBwDUAwUdADJBQQeATFDQASATDIq2UZmByscff7wQK7NPe2oANXVQda4YFP2m1KHJ8+bNK8RWrFgRtl2zZk0h9sgjj4Rto1Wd0aDm9OnTw8dHh0xfuHAhbLtlS3Gr+rfeeits+8knnxRip06dCttGe7ozKAoA6CgKOgBkgoIOAJmgoANAJpo5U/R1MztuZntGxDaY2REz+6Tx56n2dhOoHrmN3DSzH/obkn4h6Vc3xX/u7j+tvEcZaXU5furxZfZej56DbQL+3xvKKLejPc7N4okR0cn0qdkk0VL41HL+O++8sxCbOXNmITZ16tTw8bfffnshdujQoab7NW3atLBttAVC9POS0j+zSDfMfhlp1Dt0d39P0ukO9AXoKHIbuWnlPfTnzezPjf+2Fie1AvVFbqOWxlrQfynpXkkPSeqT9LNUQzNba2a7zGzXGF8L6CRyG7U1poLu7v3uftXdr0n6d0mrbtF2k7uvdPeVY+0k0CnkNupsTIdEm1mvu984YfUHkvbcqj3GJjV4WWabgGgANbVNAAdV1yO3U4N50RL71OBjtCVANCApxQOoFy9eDNtGByxfvny5EEsNXkbx1BL9oaGhQiy1d3r0c0gNfkbxbhv8TBm1oJvZm5JWS1pgZoclrZe02swekuSSDkj6cRv7CLQFuY3cjFrQ3f3ZIPxaG/oCdBS5jdywUhQAMkFBB4BMUNABIBNjmuWC8bVx48ZCrMxhGNHMF4lZLnWRmskRHRqRmk0yZ86cQmzx4sVh22g5f09PT9j25MmThdjVq1cLsWjmixT3t8yMmNSsnjLbIpTRbdsEcIcOAJmgoANAJijoAJAJCjoAZIJB0RqKBi9TA5rRYGmZAVSMr2jQLdrbW5Jmz55diKUGCaM9ymfNmhW2nTevuOFkapuAY8eOFWL9/f2FWGo/9SVLlhRiixYtCttGWxJEWxqkVLEferdtE8AdOgBkgoIOAJmgoANAJijoAJAJCjoAZIJZLpmItgOQys1oidqyHUBnlDlsITXLJZp5Ei3xl6QFCxYUYqlZLjNmzCjEzp8/H7aNlv4fOXIkbBu5dOlSITY4OBi2LbOcP9ouocxsltSMmG
47+II7dADIBAUdADJBQQeATIxa0M1sqZltN7O9Zvapmb3QiPeY2TYz29f4WFxOBnQxchu5aWZQdFjSS+7+sZnNlvSRmW2T9PeS3nH3V8xsnaR1kn7Svq7iVqoYvJyAg6JZ5XY0yBftkS7Fy+nLLLG/ePFi2DYaPIz2LU9tHRANzKbatjogmXp8ty3nL2PUO3R373P3jxufD0jaK+kuSU9L2txotlnSM+3qJNAO5DZyU+o9dDO7R9IKSTslLXL3Pun6XwxJC6vuHNAp5DZy0PQ8dDObJeltSS+6+7lm53Ca2VpJa8fWPaD9yG3koqk7dDObousJ/2t3/10j3G9mvY3v90o6Hj3W3Te5+0p3X1lFh4EqkdvISTOzXEzSa5L2uvurI761VdJzjc+fk7Sl+u4B7UNuIzc22uitmT0m6X1JuyXdGO5+Wdffa/ytpG9J+krSD9399CjPVY+h4oyUGZ2v4hT08ebuTV9EN+V2aml5JJo1Ikm9vb2F2LJly8K2DzzwQCG2atWqsG008yR1QMWFCxcKsWg5f3TAhhTPykkd0vHll18WYjt37gzb7tq1qxA7ceJE2HZoaKgQi2b6SOX+flUwK2fU3B71PXR3/5Ok1BN9r2yngG5BbiM3rBQFgExQ0AEgExR0AMhEtvuhp/YBL7M/eLTsvZNL4ctcw/r169vbGbRVFcvQo8HH1GBeNLAaDQZK0uXLlwuxaI90KR7sjK7h3Llz4eOvXr1aiB09ejRs29fXV4gdPHgwbDswMFCIDQ8Ph22jn28nBz9bwR06AGSCgg4AmaCgA0AmKOgAkAkKOgBkYsLNcikzGySHmSPRrJyNGzd2viO4pTKzXKKZIJI0ODhYiJ0+He9Y0N/fX4ilZq5Ey/Hnzp3bdB+ibQ2OHw/3O9PJkycLsbNnz4Zt9+zZU4gdOHAgbHv+/PlCLPVzjGYGZXPABQCgHijoAJAJCjoAZIKCDgCZGHU/9EpfrAv2Q48GS8tsB9CugdIyWwrs2LGj6cd3cquCblBmP/QqtSu3o0HRKVOmhG2jfcPvuOOOsG1PT08hlhrojF5v/vz5Tbe9cuVKIXbmzJnw8dE+68eOHWu6bbRNgZQeAK2TZnKbO3QAyAQFHQAyQUEHgEw0c0j0UjPbbmZ7zexTM3uhEd9gZkfM7JPGn6fa312gOuQ2ctPMStFhSS+5+8dmNlvSR2a2rfG9n7v7T9vXPaCtyG1kpfQsFzPbIukXkr4r6XyZpO+GWS7IWyuzXOqe29ES+ygmxbNnypxsX2argjLqcpDEeKh8louZ3SNphaSdjdDzZvZnM3vdzOaV7iHQJcht5KDpgm5msyS9LelFdz8n6ZeS7pX0kKQ+ST9LPG6tme0ys10V9BeoHLmNXDT1louZTZH0B0l/dPdXg+/fI+kP7v7AKM8zsf6PhI4r+5ZLTrnNWy55q+QtF7v+G3pN0t6RCW9mvSOa/UBScS9LoIuR28jNqHfoZvaYpPcl7ZZ045/wlyU9q+v/JXVJByT92N2Lx3B/87km1j+p6Lgyd+jkNuqkmdyecHu5IG+57eUC3MBeLgAwgVDQASATFHQAyAQFHQAyQUEHgExQ0AEgExR0AMgEBR0AMkFBB4BMNHPARZVOSjrY+HxB4+vccF3j5+5xfO0buV2Hn9NY5XptdbiupnK7o0v/v/HCZrvcfeW4vHgbcV0TW84/p1yvLafr4i0XAMgEBR0AMjGeBX3TOL52O3FdE1vOP6dcry2b6xq399ABANXiLRcAyETHC7qZPWlmfzGzL8xsXadfv0qNE+GPm9meEbEeM9tmZvsaH2t3YryZLTWz7Wa218w+NbMXGvHaX1s75ZLb5HX9ru2GjhZ0M5sk6d8k/Z2k70h61sy+08k+VOwNSU/eFFsn6R13Xy7pncbXdTMs6SV3v0/S30j6x8bvKYdra4vMcvsNkde11Ok79FWSvnD3/e4+JOk3kp7ucB8q4+7vSTp9U/hpSZsbn2+W9ExHO1UBd+
9z948bnw9I2ivpLmVwbW2UTW6T1/W7ths6XdDvknRoxNeHG7GcLLpxoHDj48Jx7k9LzOweSSsk7VRm11ax3HM7q999rnnd6YIeHXLKNJsuZWazJL0t6UV3Pzfe/ely5HZN5JzXnS7ohyUtHfH1EklHO9yHdus3s15Janw8Ps79GRMzm6LrSf9rd/9dI5zFtbVJ7rmdxe8+97zudEH/UNJyM/u2mU2V9CNJWzvch3bbKum5xufPSdoyjn0ZEzMzSa9J2uvur474Vu2vrY1yz+3a/+4nQl53fGGRmT0l6V8lTZL0urv/c0c7UCEze1PSal3fra1f0npJ/ynpt5K+JekrST9095sHmLqamT0m6X1JuyVda4Rf1vX3G2t9be2US26T1/W7thtYKQoAmWClKABkgoIOAJmgoANAJijoAJAJCjoAZIKCDgCZoKADQCYo6ACQif8DeVeYw6MqWVkAAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x2933a0a5ef0>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 10; last KL=334.7369689941406; last RL=13866.205078125\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAC7CAYAAAB1qmWGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAEYFJREFUeJzt3VuMVXWWx/Hf4qYIqHgBpWSQePfBW0oco8ZLa8fxxUtsL4nGMRo63k14aMKD3U688NDaPrRppdVAG9CYeONBZ0aJimNGAdGITnlBZAQpKREVULkUrHngMCn5ry2n6uxzqva/vp+EnFOr1j7nv6nFYtf5//fe5u4CAFTfkP4eAACgHDR0AMgEDR0AMkFDB4BM0NABIBM0dADIBA0dADJBQweATDTU0M3sQjP7xMyWm9n0sgYF9DdqG1VkfT1T1MyGSvpU0gWSVktaLOlqd/+f8oYHtB61jaoa1sC2UyQtd/cVkmRmT0u6WFJh0ZsZ1xlAU7m7lfAy1DYGnHpqu5GPXNokrerx9epaDKg6ahuV1MgRevS/RXKUYmZTJU1t4H2AVqO2UUmNNPTVkib2+PowSWt2T3L3WZJmSfxaisqgtlFJjXzksljSUWY22cxGSLpK0vxyhgX0K2obldTnI3R37zazWyX9h6Shkp5w949KGxnQT6htVFWfly326c34tRRNVtIql16jttFszV7lAgAYQGjoAJAJGjoAZIKGDgCZoKEDQCZo6ACQCRo6AGSChg4AmaChA0AmaOgAkAkaOgBkgoYOAJmgoQNAJmjoAJCJRu5YBACVYJZeeTaKFdmxY0eZw2kajtABIBM0dADIBA0dADJBQweATDQ0KWpmKyVtlLRdUre7t5cxqNy1tbUlsXnz5iWxl19+Odw+yv3yyy8bHxj+32Cs7aJJwiFD0uO+otwRI0YksWhCsehexlFud3d3mBuJxipJw4alra43Yygy0CZLy1jlcq67ryvhdYCBhtpGpfCRCwBkotGG7pL+08zeNbOpZQwIGCCobVROox+5nOHua8xsnKRXzOxjd1/YM6H2j4F/EKgaahuV09ARuruvqT12SXpe0pQgZ5a7tw+GSSXkg9pGFfX5CN3MRkka4u4ba89/K+nfShtZxqZPn57ETjnllCQ2ZUrSQyRJp512WhK7/fbbw9xVq1b1cnTIqbZ7s3Jl+PDhYe4+++yTxMaNGxfm7rvvvkksWvmy1157hdtH4917773D3O+++y6Jffvtt3Xnbtq0KczdunVrGK+CRj5yGS/p+doPYJikee7+76WMCuhf1DYqqc8N3d1XSDqxxLEAAwK1japi2SIAZIKGDgCZsKLTX5vyZmate7MB7Iorrkhi0an7Y8aMCbd/8MEHk9jPP/8c5t54441J7Jtvvglzo/fbuHFjmNvZ2RnG+5u713+R6xIN1NoeOnRoGB85cmQSK6q3sWPHJrHx48eHuUcccUQSO/LII5PYpEmTwu2jCdDotH0prs1ly5aFuW+++WYS+/TTT8PcDRs2JLEtW7aEua3sn/XUNkfoAJAJGjoAZIKGDgCZoKEDQCZo6ACQiTKuh45eeuaZZxra/rzzzkti0WoWSXr11VeT2IEHHlj3e02bNi2MP/TQQ3W/Blqj0TvbF63YiFaeHHvssWFue3t6WZtolcy2bdvC7VevXl33uCZPnpzEDjvssDA3WsFTxg0uBhqO0AEgEzR0AMgEDR0AMkFDB4BMMClaQdGp+/fff3+Y+9hjjyWxSy+9NMyNJjrnz5/fy9Ghv0STfEUTfN3d3Ums6DIBEyZMSGJHH310mBtdU33x4sVJ7PPPPw+3j8Z7/PHHh7nRZQaKLl+wefPmumKStH379iTWylP8G8EROgBkgoYOAJmgoQNAJmjoAJCJPTZ0M3vCzLrM7MMesQPM7BUz+6z2mJ4KBgxw1DZyU88ql9mS/irpHz1i0yUtcPeZZja99vUfyh
8eGvX9998nsZtuuinMffTRR5PYihUrSh/TADJbmdd20eqMaJVLtEJFko477rgkVnT5iJUrVyaxpUuXJrF169aF20eve8ghh9Sd+9NPP4W53333XRLbunVrmFuVFS2RPR6hu/tCSet3C18saU7t+RxJl5Q8LqDpqG3kpq+foY93905Jqj2OK29IQL+itlFZTT+xyMymSpra7PcBWo3axkDT1yP0tWZ2qCTVHruKEt19lru3u3t6XU1g4KG2UVl9PUKfL+k6STNrjy+WNiKUaubMmUnsmGOOCXOvv/76Zg+nCrKq7d5M8I0ePTqMH3DAAUmsaPJx+fLlSezrr79OYtHp9VJ8SYFTTjklzG1ra0tin3zySZgbTYoWjSHrSVEze0rSf0s6xsxWm9kN2lnsF5jZZ5IuqH0NVAq1jdzs8Qjd3a8u+NZvSh4L0FLUNnLDmaIAkAkaOgBkgoYOAJngBheZuPzyy8P4zTffnMSKbobx/vvvlzomDFxmlsRGjRoV5o4fPz6JFZ02Hxk2LG0z0coZSTrjjDOSWNGqrGiVyrJly8LcaJVL0c0/sl7lAgCoBho6AGSChg4AmaChA0AmmBStoIkTJyaxBx54IMz94osvktgjjzxS+phQLdHEX9H10CMjR46sOx6doj9hwoRw++ja6/vvv3+YG03iL1y4MMzdvHlzEqvy5GcRjtABIBM0dADIBA0dADJBQweATDApWkHz5s1LYj///HOYe/755yexrq7CezZgkIgmBDs7O8Pczz//PIntu+++YW50BunBBx+cxCZPnhxuH034F3n77beTWLQIQCq+9nkkOou2yECbWOUIHQAyQUMHgEzQ0AEgEzR0AMhEPfcUfcLMuszswx6xP5nZV2b2fu3PRc0dJlA+ahu5qWeVy2xJf5X0j93if3H3P5c+ooy0t7eH8eju6tHp0UXOPPPMJHbhhReGuWvWrKn7dQeh2RqktR2tzli3bl2Y+8YbbySxsWPH1v1e0fXQx4wZE+ZG12SPrmUuSS+88EIS27BhQ5jbm9Uo0SqXgbaapcgej9DdfaGk9S0YC9BS1DZy08hn6Lea2Qe1X1vr/+8aGPiobVRSXxv63yQdIekkSZ2S4kv9STKzqWa2xMyW9PG9gFaitlFZfWro7r7W3be7+w5Jf5c05VdyZ7l7u7vHHygDAwi1jSrr06n/Znaou+86T/hSSR/+Wv5gcNdddyWx2267LcyNJpR6c7rx008/ncQWLVpU93tt27YtzI2uGd3d3V33uHIwmGu7qC6i0+lXrVoV5u69995JLLrJ9IEHHhhuP3To0CT2+uuvh7kffPBBEivah95Maka5VZkU3WNDN7OnJJ0j6SAzWy3pj5LOMbOTJLmklZJ+38QxAk1BbSM3e2zo7n51EH68CWMBWoraRm44UxQAMkFDB4BM0NABIBPc4KKXrrnmmjB+1VVXJbGzzjorzI1m+O++++4kdu6554bbX3nllUnsiiuuCHN37NiRxObOnRvmRqtyNm7cGOai2qJVG0Urmn744YckFp3OX/S6hxxySBI76aSTwu3Xr09P3H3yySfD3OiSAFG9DyYcoQNAJmjoAJAJGjoAZIKGDgCZYFK0l3788ccwPmPGjCT28ccfh7nHHntsEjv99NOT2EsvvRRuf++99/7aEH8hmuhasoRrSQ12vZkU3b59exIbOXJkmBvFTz311CR28MEHh9t3dHQksaJ63bp1axgfzDhCB4BM0NABIBM0dADIBA0dADJBQweATLDKpZeef/75hl/jlltuSWLRCoMbbrgh3L6rq6vhMQC7KzptfsiQ9LgvuhGFFN/M4uyzz05imzZtCrdfsGBBElu3bl2YO9hP849whA4AmaChA0AmaOgAkIk9NnQzm2hmr5lZh5l9ZGZ31OIHmNkrZvZZ7TG9GzEwgFHbyE09k6Ldkqa5+1IzGyPpXTN7RdK/Slrg7jPNbLqk6ZL+0LyhVs8JJ5wQxi+77LIkdtdddyUxJj
+bjtruwczC+IgRI5LY6NGjw9wTTzwxiW3bti2JrVy5Mtx+/vz5SWzLli1hLlJ7PEJ39053X1p7vlFSh6Q2SRdLmlNLmyPpkmYNEmgGahu56dVn6GZ2uKSTJb0jaby7d0o7/2FIGlf24IBWobaRg7rXoZvZaEnPSrrT3TcU/XoWbDdV0tS+DQ9oPmobuajrCN3Mhmtnwc919+dq4bVmdmjt+4dKCj/wdfdZ7t7u7u1lDBgoE7WNnNSzysUkPS6pw90f7PGt+ZKuqz2/TtKL5Q8PaB5qG7mp5yOXMyRdK2mZmb1fi82QNFPSM2Z2g6QvJf2uOUOsrjvvvDOMR3dBf+utt5o9HKSo7R6KTuffa6+9kth+++0X5k6YMCGJbd68OYm999574fYrVqxIYpziX789NnR3/y9JRR8q/qbc4QCtQ20jN5wpCgCZoKEDQCZo6ACQCa6HXpJrr702iV1zzTVh7sMPP5zEli5dWvqYgKI19dEE6LBhcTuITv2PrnsuSfvvv38SW79+fRJbtGhRuP1PP/0UxlEfjtABIBM0dADIBA0dADJBQweATNDQASATrHIpybRp05LYhg0bwtx77rkniXF6MxoVrWgpWuUSxYcPHx7mjho1KolFl6+QpCFD0mPE6HT+aOULGscROgBkgoYOAJmgoQNAJmjoAJAJJkV7qa2tLYxPmjQpid13331hbldXeAMcoCGNTooWXQ/9oIMOSmLRdc+LXnfLli1JrOjfwPbt28M46sMROgBkgoYOAJmgoQNAJuq5SfREM3vNzDrM7CMzu6MW/5OZfWVm79f+XNT84QLlobaRm3omRbslTXP3pWY2RtK7ZvZK7Xt/cfc/N294QFNR28hKPTeJ7pTUWXu+0cw6JMVLPQaBr776KoyPHTu2xSNBo3KrbXdPYkWrXKLcolUukaJVKtEp/dFNK77++utw++7u7rrHgFSvPkM3s8MlnSzpnVroVjP7wMyeMDM6GiqL2kYO6m7oZjZa0rOS7nT3DZL+JukISSdp51HOAwXbTTWzJWa2pITxAqWjtpGLuhq6mQ3XzoKf6+7PSZK7r3X37e6+Q9LfJU2JtnX3We7e7u7tZQ0aKAu1jZzUs8rFJD0uqcPdH+wRP7RH2qWSPix/eEDzUNvIjUWTI79IMDtT0puSlknaddHuGZKu1s5fSV3SSkm/r00y/dpr/fqbAQ1y93gWMDCYazuaLI2uZS5JI0eOTGIjRoyoOze61v/atWvD7bkvQLF6anuPDb1MVSt6VE9vGnqZqlbbNPTqqae2OVMUADJBQweATNDQASATNHQAyASTosgKk6LIFZOiADCI0NABIBM0dADIBA0dADJRzw0uyrRO0v/Wnh9U+zo37Ff/mdSP772rtqvw99RXue5bFfarrtpu6SqXX7yx2ZIcr1LHfg1uOf895bpvOe0XH7kAQCZo6ACQif5s6LP68b2bif0a3HL+e8p137LZr377DB0AUC4+cgGATLS8oZvZhWb2iZktN7PprX7/MtXuCN9lZh/2iB1gZq+Y2We1x8rdMd7MJprZa2bWYWYfmdkdtXjl962Zcqlt6rp6+7ZLSxu6mQ2V9LCkf5F0vKSrzez4Vo6hZLMlXbhbbLqkBe5+lKQFta+rplvSNHc/TtI/S7ql9nPKYd+aIrPani3qupJafYQ+RdJyd1/h7lslPS3p4haPoTTuvlDS+t3CF0uaU3s+R9IlLR1UCdy9092X1p5vlNQhqU0Z7FsTZVPb1HX19m2XVjf0Nkmreny9uhbLyfhdNxSuPY7r5/E0xMwOl3SypHeU2b6VLPfazupnn2tdt7qhR9fzZZnNAGVmoyU9K+lOd9/Q3+MZ4Kjtisi5rlvd0FdLmtjj68MkrWnxGJptrZkdKkm1x65+Hk+fmNlw7Sz6ue7+XC2cxb41Se61ncXPPve6bnVDXyzpKDObbGYjJF0laX6Lx9Bs8yVdV3t+naQX+3EsfWJmJulxSR3u/mCPb1V+35oo99
qu/M9+MNR1y08sMrOLJD0kaaikJ9z93pYOoERm9pSkc7Tzam1rJf1R0guSnpH0T5K+lPQ7d999gmlAM7MzJb0paZmkHbXwDO38vLHS+9ZMudQ2dV29fduFM0UBIBOcKQoAmaChA0AmaOgAkAkaOgBkgoYOAJmgoQNAJmjoAJAJGjoAZOL/ANlywWOHz2B3AAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x2932eca9518>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 20; last KL=342.6477355957031; last RL=13823.6806640625\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAC7CAYAAAB1qmWGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAEydJREFUeJzt3XtsleWWx/Hf4qqCgIAiYrmIZsQYB7XoxDGjk6NGDfGSeMzxH5SMwctoMBojQZNDxkyiRj3zx4wiE42cxKOeiA7mGEeImugYMeIlKlMcELlJ04qoFLQods0fbDIdnvXa3Xbv3b4P309C2q4+u/t529XF2/3czN0FACi/IQPdAQBAbVDQASATFHQAyAQFHQAyQUEHgExQ0AEgExR0AMgEBR0AMtGvgm5ml5rZ52a20cwW1apTwEAjt1FG1teVomY2VNL/SLpY0nZJ70u6zt3/u3bdAxqP3EZZDevHY8+RtNHdN0mSmT0n6UpJhUlvZuwzgLpyd6vBlyG3MehUk9v9eclliqRt3T7eXokBZUduo5T6c4ce/W+R3KWY2QJJC/rxPECjkdsopf4U9O2Smrp9fKKkHYc2cvdlkpZJ/FmK0iC3UUr9ecnlfUmnmNkMMxsh6XeSXq5Nt4ABRW6jlPp8h+7u+83sNkmvSRoq6Sl3X1ezngEDhNxGWfV52mKfnow/S1FnNZrl0mvkNuqt3rNcAACDCAUdADJBQQeATFDQASATFHQAyAQFHQAyQUEHgExQ0AEgExR0AMgEBR0AMkFBB4BMUNABIBMUdADIBAUdADLRnxOL0M3pp5+exN54442wbUdHRxK78847k9jKlSv73zGgSmbx7qxDhqT3fcOGxaVj5MiRSWz48OFJbOjQoeHjo/gvv/wStv3pp5+S2M8//xy27erqqipWFC/qQ7T9eCO3JD8Ud+gAkAkKOgBkgoIOAJmgoANAJvo1KGpmmyV1SPpF0n53b65Fp8po3rx5SWzixIlh2yj+wgsvJLEHHnggfPySJUuSWNGgDfqmjLkdDWoWDT6OGDEiiY0bNy5sO3Xq1CQ2Z86csO20adOS2Pjx45NYNFAqxYOtO3fuDNtG8S+//DJsu23btiT27bffhm137dqVxL7//vuwbTQwu3///rBtFK/1AGotZrn8vbvH33Gg3MhtlAovuQBAJvpb0F3SKjP7wMwW1KJDwCBBbqN0+vuSy9+6+w4zO07SajNb7+5vdW9Q+WXgFwJlQ26jdPp1h+7uOypv2yW9JOmcoM0yd28uw6AScBC5jTLq8x26mY2SNMTdOyrvXyLpn2rWs5I59thj+/X4aDbCfffdF7adMmVKErvpppvCtkVLoVFssOd20RL9KIeOPPLIsG2UQ+eck/yfJUk677zzkthpp50Wto1mcB1xxBFJrGj2TbTNwL59+8K20SyXzZs3h23XrVuXxDZu3Bi2/eKLL6rqlyTt3bs3ie3ZsydsG81EG0yzXCZJeqmSXMMk/cnd/7MmvQIGFrmNUupzQXf3TZL+uoZ9AQYFchtlxbRFAMgEBR0AMsF+6DVy4403JrHjjjsubHvZZZf167luuOGGJFY0uHLLLbcksWi5MsqjaFA0Wk4/ZsyYsO2MGTOSWLSnvySdfPLJSWzChAlh22gQPlpi39nZGT4+Gnwsut7oa0T7sUvxtgbRYK0UDyQX7f8eKepvI3CHDgCZoKADQCYo6ACQCQo6AGSCgg4AmWCWS41Ey3qj2SiS9M477ySxaCbBd999Fz5+7NixSWz+/Plh2y1btiSxpUuXhm3b29vDOMohmiFSNOsjmuFRdOjEjz/+mMSivJKkr776Kolt3bo1iRUdLtGbmTqjR48O45G2trYkVvT71dHRkcSi74EUb0tQdMBFrZf5R7hDB4BMUNABIBMUdADIBAUdADLBoGgdff3112H8wQcfTGL33HNPEps3b174+Oeffz6JNTU1hW2XLF
mSxIqWPC9evDiMoxy6urqqbhvtR96bfceLBjVbWlqS2KZNm5JY0SBjNNBZtM1AtPd60bL7aFA0iknS7t27k1hRf6OtDqIJEo3CHToAZIKCDgCZoKADQCYo6ACQiR4Lupk9ZWbtZvZZt9h4M1ttZhsqb4+pbzeB2iO3kRvraTmqmf2dpD2S/ujup1diD0na5e4PmNkiSce4ezpNI/1a9V/7ehiIlv6vWrUqbDtnzpwkVjRif8011ySxV199tZe9G1juXvXpAmXN7aKZHNHspSlTpoRtzzrrrCR25plnhm2jfGttbQ3brl+/vqq20TYFUjyjJTqcQoqvtyi3d+zYUVVMirfA2Lt3b9g2OiymXkv/q8ntHu/Q3f0tSbsOCV8paXnl/eWSrup174ABRm4jN319DX2Su7dKUuVtfNYaUD7kNkqr7guLzGyBpAX1fh6g0chtDDZ9vUNvM7PJklR5W7jvqrsvc/dmd2/u43MBjURuo7T6eof+sqTrJT1QebuyZj1Cj77//vsktnDhwrDt6tWrk9ioUaPCts3NaV0q26BoDZQ2t6NBt6KBuP5uE1A0UDlz5swkFi3RHzFiRPj4o446qqrnl6TOzs4kVrTHefQ7Ey3xL/q6RQOdvfk+NkI10xaflfSupL8ys+1m9g86kOwXm9kGSRdXPgZKhdxGbnq8Q3f36wo+9Zsa9wVoKHIbuWGlKABkgoIOAJmgoANAJjjgIhNr1qwJ49Gof9Esl5deeqmmfUJ9FC39j+LR0nQpPswiOqxBkkaOHJnEopkrknTSSSclsWimTdHMlWjWSDRDRZK+/PLLJFZ0DVG8aOZKdEBF0aEVUX/7u8S/P7hDB4BMUNABIBMUdADIBAUdADLBoCiQiWgwrmjgb8+ePUmsaC/xaFA02iNdivczLxoAjUR9GDYsLlNtbW1JLNojXSreaiAykIOa/cUdOgBkgoIOAJmgoANAJijoAJAJBkUzcdFFF4XxaJCqpaUlbFu0PzQGl6JBu2g1Y7QiVJK+/vrrJLZt27awbbRHedFgayQaVC0avCzqbyTK7eOPPz5sG11v0Yrb3gyKDrYBVO7QASATFHQAyAQFHQAyQUEHgExUc6boU2bWbmafdYstMbOvzOzjyr/L69tNoPbIbeSmmlkuT0v6V0l/PCT+B3d/uOY9Qp9cffXVYTyaTfDww/GPbevWrTXtUwk8rRLmdm9muRQt5//mm2+S2Pr168O2nZ2dSWzXrl1h28mTJyex4cOHJ7Gi7QCimSdHHnlk2Dba1z/aj12Kr/fzzz+vug+DbTZLkR7v0N39LUnxTw8oMXIbuenPa+i3mdknlT9bj6lZj4CBR26jlPpa0B+XNFPSbEmtkh4pamhmC8xsrZmt7eNzAY1EbqO0+lTQ3b3N3X9x9y5J/y7pnF9pu8zdm929ua+dBBqF3EaZ9Wnpv5lNdvfWyodXS/rs19qjb4oO4n388ceT2Ny5c8O277//fhJ78cUX+9exjJU5t6OBu6JDk3/44YckFh0oLsXL5osOTW5vb09i0XL+vXv3ho+PthmYPn162Pbss89OYk1NTWHb6KDpTz75JGwbXW9ZBkV7LOhm9qykCyVNNLPtkn4v6UIzmy3JJW2WdFMd+wjUBbmN3PRY0N39uiD8ZB36AjQUuY3csFIUADJBQQeATFDQASATHHAxSMyfPz+J3X777WHb2bNnJ7HXXnstbDtv3rwkFo344/DS1dWVxIoOrYi2D+jNjJhvv/02iRXNcokOwyiaqTNr1qwkdsopp4RtI++9914Yj7ZAKMssF+7QASATFHQAyAQFHQAyQUEHgEwwKFpHw4bF394HH3wwiS1YsCCJRfs9S9KiRYuS2COPxHtIFS3RRnkVnVY/ZEh6fxbtRS7FS+xHjx4dto321I+W80vxgHu0F3nRPu3RcxUNwEbfh6OPPjpsGw0CF+2zHrVlUBQA0FAUdADIBAUdADJBQQeATFDQASATzHKpkWhGy2233Ra2Xb
hwYRKLlkxfe+214ePXrFmTxJjNcvgomuUS5WA0m0WSjj322CQ2adKksG00UyZazi9JnZ2dVcWimSRSPFOnaObKmDFjkljRzJWov3v27Anblvl3iTt0AMgEBR0AMkFBB4BM9FjQzazJzN40sxYzW2dmCyvx8Wa22sw2VN4eU//uArVDbiM31QyK7pd0l7t/aGZHS/rAzFZLukHS6+7+gJktkrRI0j316+rgdv/99yexW2+9NWx77733JrFnnnkmie3atSt8/Lhx45JY0eBXZPfu3WG8o6Oj6q+RiVLmdi0GRSdOnJjETjjhhLBtNIBZNHAYDTRGWwoUXcO0adOS2Lnnnhu2nTFjRhKLBlUlaceOHUlsy5YtYduifeHLoMc7dHdvdfcPK+93SGqRNEXSlZKWV5otl3RVvToJ1AO5jdz06jV0M5su6UxJ70ma5O6t0oFfDEnH1bpzQKOQ28hB1fPQzWy0pBWS7nD33UV/MgWPWyAp3UoQGCTIbeSiqjt0MxuuAwn/jLu/WAm3mdnkyucnS2qPHuvuy9y92d2ba9FhoJbIbeSkmlkuJulJSS3u/mi3T70s6frK+9dLWln77gH1Q24jN9bTxu1mdr6ktyV9KungcPdiHXit8c+SpkraKum37h5Py/i/r1WOXeIrJk+enMSeeOKJsO0FF1yQxIqWLH/88cdJLDoE4MQTTwwfP3369CRW9HOMDgxYtmxZ2Pbmm28O42Xi7tW9XqLy5nbRwSnRjJaimSuzZs1KYqeeemrYNsqhKF+leAuLyIQJE8L4nDlzklhzc/wHUDRTp709/GNKjz32WBJ77rnnwrbRgRqD4YCLanK7x9fQ3f2/JBV9od/0tlPAYEFuIzesFAWATFDQASATFHQAyAT7of+Khx9+OInNnTu331939uzZ/Xr8K6+8ksTuvvvusO1DDz2UxKIBIgxO0Zz4ouXtQ4cOrSomSaNGjUpiTU1NYdtoED7afkKKB1Cj2NixY8PHRxMJitYFRAOgK1asCNuuWrUqiRXthz4YBkD7ijt0AMgEBR0AMkFBB4BMUNABIBMUdADIBLNcfsXatWuTWNGI+yWXXJLEig6MiDbxf/vtt5PYu+++Gz5+6dKlSWzz5s1h2yuuuCKMo7x6Mwtj3759YXznzp1JrGjZ/syZM5NY0YyYSZMmJbHhw4cnsaIDMqI+bNiwIWy7cmW6xU7RLJfW1tYkVuaDLIpwhw4AmaCgA0AmKOgAkAkKOgBkosf90Gv6ZCXbDx3l05v90GupkbldtJx/5MiRSSzaI12K9xKfOnVq2PaMM85IYjNmzAjbRvucRzWmra0tfPxHH32UxD744IOw7ZYtW5LYDz/8ELYtGoQtk2pymzt0AMgEBR0AMkFBB4BMVHNIdJOZvWlmLWa2zswWVuJLzOwrM/u48u/y+ncXqB1yG7mpZqXofkl3ufuHZna0pA/MbHXlc39w93TTcKAcyG1kpZpDolsltVbe7zCzFklT6t0xoN7KmttFMzZ+/PHHJNbZ2Rm2jU6237RpU9g22paiaAuMaEZLV1dXEiu6hihe5gMnGq1Xr6Gb2XRJZ0p6rxK6zcw+MbOnzOyYGvcNaBhyGzmouqCb2WhJKyTd4e67JT0uaaak2Tpwl/NIweMWmNlaM0t3ugIGAXIbuahqYZGZDZf0F0mvufujweenS/qLu5/ew9fhbyfUVW8XFuWU29HLIEUvjUTnkvbmrFJecmm8miwssgM/uScltXRPeDOb3K3Z1ZI+60sngYFCbiM3Pd6hm9n5kt6W9Kmkg//VLpZ0nQ78SeqSNku6qTLI9Gtfi/9qUVe9uUMnt1Em1eQ2e7kgK4fDXi44PLGXCwAcRijoAJAJCjoAZIKCDgCZoKADQCYo6ACQCQo6AGSCgg4AmaCgA0AmqjngopZ2Sjp4VPfEyse54boGzrQBfO6DuV2G71Nf5XptZbiuqnK7oUv//98Tm6119+YBefI64roObzl/n3K9tpyui5dcAC
ATFHQAyMRAFvRlA/jc9cR1Hd5y/j7lem3ZXNeAvYYOAKgtXnIBgEw0vKCb2aVm9rmZbTSzRY1+/lqqnAjfbmafdYuNN7PVZrah8rZ0J8abWZOZvWlmLWa2zswWVuKlv7Z6yiW3yevyXdtBDS3oZjZU0r9JukzSaZKuM7PTGtmHGnta0qWHxBZJet3dT5H0euXjstkv6S53nyXpbyT9Y+XnlMO11UVmuf20yOtSavQd+jmSNrr7Jnf/SdJzkq5scB9qxt3fkrTrkPCVkpZX3l8u6aqGdqoG3L3V3T+svN8hqUXSFGVwbXWUTW6T1+W7toMaXdCnSNrW7ePtlVhOJh08ULjy9rgB7k+/mNl0SWdKek+ZXVuN5Z7bWf3sc83rRhf06JBTptkMUmY2WtIKSXe4++6B7s8gR26XRM553eiCvl1SU7ePT5S0o8F9qLc2M5ssSZW37QPcnz4xs+E6kPTPuPuLlXAW11Ynued2Fj/73PO60QX9fUmnmNkMMxsh6XeSXm5wH+rtZUnXV96/XtLKAexLn5iZSXpSUou7P9rtU6W/tjrKPbdL/7M/HPK64QuLzOxySf8iaaikp9z9nxvagRoys2clXagDu7W1Sfq9pP+Q9GdJUyVtlfRbdz90gGlQM7PzJb0t6VNJXZXwYh14vbHU11ZPueQ2eV2+azuIlaIAkAlWigJAJijoAJAJCjoAZIKCDgCZoKADQCYo6ACQCQo6AGSCgg4AmfhfkOqtlYoeWRAAAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x2930ac3c048>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 30; last KL=326.5901794433594; last RL=13260.6181640625\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAC7CAYAAAB1qmWGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAE7xJREFUeJzt3WuM1eW1x/HfcgSpXKyUgshtqBeEHiMoVdLaBNs0FmNjSaOpL6wnNcUXtrGJTUrqizansfWFtiepx0ZOtVhjbTRtj2hTKyGtVkVFiOLAcBkpyMAoIiqgIg6s84JNMuVZf9gz+zL7/8z3k5CZWbP23s8fnln8Zz83c3cBAMrvpMFuAACgPijoAJAJCjoAZIKCDgCZoKADQCYo6ACQCQo6AGSCgg4AmaipoJvZV81so5l1mdniejUKGGz0bZSRDXSlqJm1Sdok6SuSuiWtknStu6+vX/OA5qNvo6xOruGxF0vqcvctkmRmf5B0laTCTm9m7DOAhnJ3q8PT0LfRcqrp27W85TJJ0vY+X3dXYkDZ0bdRSrXcoUf/WyR3KWa2SNKiGl4HaDb6NkqploLeLWlKn68nS9p5bJK7L5G0ROLXUpQGfRulVMtbLqsknWNm081suKRvSlpWn2YBg4q+jVIa8B26u/ea2Xcl/U1Sm6T73H1d3VoGDBL6NspqwNMWB/Ri/FqKBqvTLJd+o2+j0Ro9ywUA0EIo6ACQCQo6AGSCgg4AmaCgA0AmKOgAkAkKOgBkgoIOAJmoZS8XAGgpZs1dV9bMhZnV4A4dADJBQQeATFDQASATFHQAyAQFHQAywSwXAC0jmqXS1tYW5g4bNiyJDR8+PMwdPXp0Ehs1alSY+/777yexffv2hbkfffRREvv444/D3EOHDiWxes+S4Q4dADJBQQeATFDQASATFHQAyERNg6JmtlXSPkmHJPW6+9x6NApHjBkzJozPmTMniS1YsCDMvfnmm5PY+vXrw9wLL7wwiT388MNh7q233prEurq6wtwyom/XTzSoGQ1oStKpp56axMaOHRvmnn322Uls4sSJYe4nP/nJJFb087V///4ktn379jC3o6Mjie3ZsyfMjQZWP/jggzB3oIOl9Zjlcpm7767D8wCthr6NUuEtFwDIRK0F3SU9aWarzWxRPRoEtAj6Nkqn1rdcvuDuO81svKTlZrbB3Z/um1D5YeAHAmVD30bp1HSH7u47Kx93SfqzpIuDnCXuPpdBJZQJfRtlZAMdTTWzkZJOcvd9lc+XS/ovd3/iOI9prd3gW8jcuWlNeOyxx8Lc8ePH1/Ravb29YfzAgQNJrGh59NatW5NY0UybTZs2Vd+4Grl7zScc0Lf/XX+W448YMSKJTZgwIYnNnj07fPy5556bxGbMmBHmRsv5i2auRO0qmmkTLd3fsWNHmPvss88msaL+vmHDhqqfN/oZraZv1/KWywRJf678Y58s6ffH6/BAidC3UUoDLujuvkXSBXVsC9AS6NsoK6YtAkAmKOgAkAn2Qx8En/jEJ5LYnXfemcSKBj93704XLxYt0b/xxhuT2N133x3mrly5Mok99NBDYW57e3sSu+WWW6puA1pPNPgpxXuMjxs3LsydNWtWElu4cGESiyYBSHGfL9qLPNq3vGgv8mhQNBpUleLtB4oGUN97770kFv18SvXf+zzCHToAZIKCDgCZoKADQCYo6ACQCQo6AGSCWS6DYPr06Uns0ksvTWJvvfVW+PhvfOMbSSxagixJ3/72t5PYihUrwtzu7u4wjvycdFJ6LxfNZpGkSZMmJbEvfvGLYe78+fOT2EUXXZTEogMnpHjmymuvvRbmbtmyJYlFs06kePuAz3zmM2Huhx9+mMSKfhajpftvvPFG1c9bNLNooLhDB4BMUNABIBMUdADIBAUdADLBoOggiJZHR37729+G8WgA9LrrrgtzoyXLV155ZZgbDWj1R2dnZ02PR/0VDb
qdcsopSezMM88Mc+fNm5fEPve5z4W5U6dOTWLRcvxob3BJeu6555JYR0dHmLt3794kFm2rIcXL+YsGZk8+OS2L+/fvD3OjQdwPPvggzI32kK/3dgDcoQNAJijoAJAJCjoAZIKCDgCZOGFBN7P7zGyXmXX0iY01s+Vmtrny8fTGNhOoP/o2clPNLJelku6S9Ls+scWSVrj77Wa2uPL1D+vfvDwVzTI51uOPPx7GzzvvvCT2s5/9LMyNRta/853vVPX6x3Pbbbclsbvuuqvm522ypcqob0czWooOZohmeEQzVKR4ifynP/3pMPfw4cNJLJrR8vzzz4ePX7NmTRJ75513wtzTTjstiU2ZMiXMjeIjR44Mc/fs2ZPE3n333TB327ZtVeceOHAgiTV9lou7Py3p2Cu8StL9lc/vl/T1urYKaAL6NnIz0PfQJ7h7jyRVPsZnpQHlQ99GaTV8YZGZLZK0qNGvAzQbfRutZqB36G+a2URJqnzcVZTo7kvcfa67x6fCAq2Fvo3SGugd+jJJ10u6vfLx0bq1aAg4dOhQVXl33313GN+8eXMSK1q23R/RAM2DDz4Y5t5zzz1JrLe3t+Y2tIDS9u1oj/Noib8UDygW9aFoAPWjjz4Kczdu3FhVLNrLXIoHFIuW6J9//vlJ7POf/3yYG11b0d7p0RYWRecNrFu3Lont3r07zI32Q48GkWtRzbTFhyStlDTDzLrN7AYd6exfMbPNkr5S+RooFfo2cnPCO3R3v7bgW1+uc1uApqJvIzesFAWATFDQASATFHQAyITVe+npcV/MrHkv1sKiAy5effXVhrxWtDF/T09PmPvTn/40iRXNcmlV7l7fY9Sr1Ap9OzqYYfTo0WHujBkzktjMmTPD3HPOOSeJjR8fr7fat29fEtu0aVMS27lzZ/j46MCISy65JMz97Gc/W3W7oj7/4osvhrmrVq1KYkWzcqKfr4MHD4a5tc4Cq6Zvc4cOAJmgoANAJijoAJAJCjoAZKLhm3MNZWeccUYYv/baovUs1YmWLN96661hbrRkee3atTW9PsqjaJuJ6GT6aL9uKR7oLNomYNy4cUksGpgtWs5/9tlnJ7Hp06eHudE2AS+//HKYGw10vvLKK2HuG2+8kcSiwVop/vut93L+/uAOHQAyQUEHgExQ0AEgExR0AMgEg6L9NHHixDD+q1/9KonNnj07zC0a5KnWt771rSRWdKA0ho5o1Xd/BkWLBv6iQb4RI0aEudHqzegw5osuuih8/KhRo5JYd3d3mBsNdBYNikYrsXftis8u+fjjj5NYM1fU14I7dADIBAUdADJBQQeATFDQASAT1Zwpep+Z7TKzjj6xn5jZDjN7ufLnisY2E6g/+jZyU80sl6WS7pL0u2Piv3T3O+reohYyefLkJLZs2bIw94ILLkhiRSej/+Mf/0hi8+fPr7pdXV1dVefiuJYqo74dzcQwq357+OHDh4fxaI/x9vb2MHfatGlJLFr6XzRrZMeOHUls+/btYe7zzz+fxIqW8+/duzeJ1bo/eSs64R26uz8taU8T2gI0FX0buanlPfTvmtnayq+tp9etRcDgo2+jlAZa0H8t6SxJsyX1SLqzKNHMFpnZS2b20gBfC2gm+jZKa0AF3d3fdPdD7n5Y0v9Kuvg4uUvcfa67zx1oI4FmoW+jzAa09N/MJrr70VNXF0rqOF5+q7v88svD+M9//vMkFg1+SvFhyn/961/D3Gjgpj+HRJ9//vlJbMOGDVU/HsVy69vRwdGSNGbMmCRWtMd5dEj0pEmTqn69aCl90Z7hr7/+ehIrGuh87bXXkli0d7tUvAVCbk5Y0M3sIUnzJY0zs25JP5Y038xmS3JJWyXd2MA2Ag1B30ZuTljQ3T06XufeBrQFaCr6NnLDSlEAyAQFHQAyQUEHgEwMuQMuzjjjjCQWHU4hSWeddVYSu+mmm8Lce+65J4kVLW+eNWvW8Zp4QgsWLEhijzzySE3PifKLZphEy+6l+CCKCy
+8MMyNDp3YsydeYPvuu+8msba2tiRWdEDG1q1bk9jGjRvD3Gg5f9HsmbIcUFEr7tABIBMUdADIBAUdADJBQQeATGQ7KPqpT30qjC9fvjyJRYOfkvTAAw8ksWjwUxo6gy4YfNEgoxQPgJ533nlh7sUXp1vUFA2gvvfee0ls8+bNYW40UBn9fEUDrZK0ZcuWJLZ///4wN9rPfKj/HHKHDgCZoKADQCYo6ACQCQo6AGSCgg4Amch2lss111wTxqNl9+vXrw9zf/CDHySxeoyiv/3220ks2qy/aPYNhraiZfPTpk1LYl/60pfC3Pb29iRWtGy+oyM942Pt2rVVt23YsGFJrGgW2kknpfeY0QEZUnxoBbNcAABZoKADQCYo6ACQiRMWdDObYmZ/N7NOM1tnZjdX4mPNbLmZba58PL3xzQXqh76N3FQzKNor6RZ3X2NmoyWtNrPlkv5T0gp3v93MFktaLOmHjWtq/8yZM6fq3NWrV4fx3bt316s5/yYaEOrPAOiuXbvq2ZyhrOX7tpklsdNPj/9/ueSSS5LYzJkzw9xTTz01iUV7kUtSZ2dnEov2PZek6dOnJ7Hx48cnsehcAknq6elJYtHWA4id8A7d3XvcfU3l832SOiVNknSVpPsrafdL+nqjGgk0An0buenXe+hm1i5pjqQXJE1w9x7pyA+GpPS/YaAk6NvIQdXz0M1slKQ/Svq+u++NfhUseNwiSYsG1jyg8ejbyEVVd+hmNkxHOvyD7v6nSvhNM5tY+f5ESeEbu+6+xN3nuvvcejQYqCf6NnJSzSwXk3SvpE53/0Wfby2TdH3l8+slPVr/5gGNQ99Gbqp5y+ULkq6T9KqZvVyJ/UjS7ZIeNrMbJL0u6erGNLHxXnnllTAebcJftNl+f0ydOrWmxy9durTmNkBSCfp29PZPNENFivtVNOtEig+zKJo9NW7cuCQ2ZcqUMHf27NlJ7Nxzz01ib731Vvj4aAuMohk1LP1PnbCgu/szkoreVPxyfZsDNA99G7lhpSgAZIKCDgCZoKADQCay3Q9906ZNVefecccdYXzevHlJ7JFHHqn6eS+77LIwfvXV1Y2x/eY3vwnj27Ztq7oNKLdof/Dhw4eHudHg5ciRI8Pc0047LYnNmDGj6ueNJgxI0tixY5PYO++8k8SeeeaZ8PGrVq2q6vFS8T7pQxl36ACQCQo6AGSCgg4AmaCgA0AmKOgAkAlr5lJZM2vai0Wb6kvSypUrk1h0AnqzPfnkk0nse9/7Xpjb1dXV6OaUlrtXt1VinTWqb598cjoRbfLkyWHuwoULk9gVV1wR5kZL98eMGRPmRlsNFM086ejoSGJPPfVUEnviiSfCx0czuD788MMwt7e3N4znqpq+zR06AGSCgg4AmaCgA0AmKOgAkIlsl/4X7e18+eWXJ7EbbrghzI1OTP/a174W5j733HNJrGh581/+8pck9sILLyQxljYj8v7774fxDRs2JLFTTjklzI2W/p955plhbrT9QLRvuST961//SmKrV6+uKk+SDhw4kMQOHz4c5iLFHToAZIKCDgCZoKADQCaqOSR6ipn93cw6zWydmd1cif/EzHaY2cuVP/EKBqBF0beRm2oGRXsl3eLua8xstKTVZra88r1funu8mTjQ+ujbyEq/l/6b2aOS7tKRE9P396fTN3PpP4amWpb+t2LfbmtrS2LRUnwpnrkyYsSIMLfoOSKHDh1KYm+//XaYu3fv3iR28ODBqp5Tkpq5FUnZ1H3pv5m1S5oj6egcu++a2Vozu8/MTu93C4EWQd9GDqou6GY2StIfJX3f3fdK+rWksyTNltQj6c6Cxy0ys5fM7KU6tBeoO/o2clHVWy5mNkzS45L+5u6/CL7fLulxd/+PEzwPv0+hofr7lkur923ecsFRdXnLxcxM0r2SOvt2eDOb2CdtoaR030yghdG3kZsT3qGb2aWS/inpVUlH1+D+SNK1OvIrqUvaKulGd+85wXPx3y8aqj936GXo29
Gy+yP/D6WiJfLc8eajmr6d7QEXGJpyO+CCgo6jOOACAIYQCjoAZIKCDgCZoKADQCayPeACyAGHO6A/uEMHgExQ0AEgExR0AMgEBR0AMtHsQdHdkrZVPh9X+To3XNfgmTaIr320b5fh72mgcr22MlxXVX27qUv//+2FzV5y97mD8uINxHUNbTn/PeV6bTldF2+5AEAmKOgAkInBLOhLBvG1G4nrGtpy/nvK9dqyua5Bew8dAFBfvOUCAJloekE3s6+a2UYz6zKzxc1+/XqqnAi/y8w6+sTGmtlyM9tc+Vi6E+PNbIqZ/d3MOs1snZndXImX/toaKZe+Tb8u37Ud1dSCbmZtkv5H0gJJsyRda2azmtmGOlsq6avHxBZLWuHu50haUfm6bHol3eLuMyXNk3RT5d8ph2triMz69lLRr0up2XfoF0vqcvct7n5Q0h8kXdXkNtSNuz8tac8x4ask3V/5/H5JX29qo+rA3XvcfU3l832SOiVNUgbX1kDZ9G36dfmu7ahmF/RJkrb3+bq7EsvJhKMHClc+jh/k9tTEzNolzZH0gjK7tjrLvW9n9W+fa79udkGPDjllmk2LMrNRkv4o6fvuvnew29Pi6NslkXO/bnZB75Y0pc/XkyXtbHIbGu1NM5soSZWPuwa5PQNiZsN0pNM/6O5/qoSzuLYGyb1vZ/Fvn3u/bnZBXyXpHDObbmbDJX1T0rImt6HRlkm6vvL59ZIeHcS2DIiZmaR7JXW6+y/6fKv019ZAufft0v/bD4V+3fSFRWZ2haT/ltQm6T53v62pDagjM3tI0nwd2a3tTUk/lvR/kh6WNFXS65KudvdjB5hampldKumfkl6VdPQMtB/pyPuNpb62Rsqlb9Ovy3dtR7FSFAAywUpRAMgEBR0AMkFBB4BMUNABIBMUdADIBAUdADJBQQeATFDQASAT/w+Hm4q+xVRLbQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x2930aebe358>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 40; last KL=340.5096740722656; last RL=14156.3173828125\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAC7CAYAAAB1qmWGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAE1JJREFUeJzt3XuMVVWWx/HfsqgCBAQLBBF5BXEcHwEjmpHGqFFaRuOjoy3tHyMT25R/jIkmGpuYaKvJBGJsHZMmHRmhQaO2rdBC+g97DDHiGEXUkFYHHd5SUrzfimjhnj+4mBr2OtSt+z6b7ychVbVY9559qvZdderux7EQggAA+XdKvRsAAKgMCjoAJIKCDgCJoKADQCIo6ACQCAo6ACSCgg4AiaCgA0AiyiroZjbNzL40s7VmNrNSjQLqjb6NPLJSV4qaWZOk/5U0VVK7pJWS7ggh/E/lmgfUHn0bedWrjMdeJmltCGG9JJnZnyTdLCmz05sZ+wygqkIIVoGnoW+j4RTTt8t5y2WEpM1dvm4vxIC8o28jl8q5Qvd+W0RXKWbWJqmtjOMAtUbfRi6VU9DbJY3s8vXZkrYcnxRCmCtprsSfpcgN+jZyqZy3XFZKGm9mY82sRdKvJC2tTLOAuqJvI5dKvkIPIXSa2b2S/iapSdL8EMLnFWsZUCf0beRVydMWSzoYf5aiyio0y6XH6NuotmrPcgEANBAKOgAkgoIOAImgoANAIijoAJAICjoAJIKCDgCJoKADQCIo6ACQCAo6ACSCgg4AiaCgA0AiKOgAkAgKOgAkopw7FgE4CZxySnzd19TU5Ob26dMnivXt27eomCQ1NzdHsZaWlu6a+JPvvvvOjf/www9F537//fdR7PDhw27ukSNHolhnZ6ebW4utyrlCB4BEUNABIBEUdABIBAUdABJR1qComW2UdEDSEUmdIYRJlWgUjpo4cWKP4uVatWpVUbGTQSp928y/DaUX9wYkJWnQoEFRbOTIkW7uRRddFMXOO++8KDZkyBD38f37949i3iCl5A9q7tmzx809cOBAFNuwYYObu3Hjxii2c+dON9c7nncsyW/vjz/+6OaWqhKzXK4OIfhnC+QbfRu5wlsuAJCIcgt6kPRfZvaxmbVVokFAg6BvI3fKfcvlZyGELWY2VNJbZvZFCGF514TCi4EXBPKGvo3cKesKPYSwpfBxu6S/SLrMyZkbQpiU10ElnJzo28ijkq/QzayfpFNCCAcKn/9c0hOVatjMmTOj2OzZsyv19HXljeR7MwFeffVV9/GjRo0q+ljesu2skfVNmzZFsenTp7u5X375ZRQ7ePBg0e1qZNXu27WUNculd+/eUWzo0KFu7qRJ8e+ra665xs298MILi3refv36uY/3ltJ7Mclfzr9jxw4315ulMmDAADfXe81k8Zb5Zy399+LeOUilbxNQzlsuwyT9pdBhekl6OYTwZhnPBzQK+jZyqeSCHkJYL2lCBdsCNAT6NvKKaYsAkAgKOgAkomH3Q1+9enW9m9AjY8aMiWIzZsxwcydMiP+av/HGGyvdpB4bPXp0FPvggw/c3KVLl0axW2+9teJtQvG8AdBevfyXuLecP2tLialTpxad29raWlS7tm7d6j5+3759USzrHLzBy6x9y7141uQAr71Zg5fetgRZubXAFToAJIKCDgCJoKADQCIo6ACQCAo6ACSiYWe5LFmypN5N0LXXXhvFnnvuOTfXu9t51lLqnizHb1TnnntuvZuA43izM1paWtxcbzbKWWed5eYOHDgwimXddMK7acTmzZujmLd1hOTfHKJv375u7hlnnBHFmpqa3FxvW4qsmTZff/11FNu1a5eb+80330SxrO+Nt4VBqUv8s3CFDgCJoKADQCIo6ACQCAo6ACSiYQdFa2nKlClu/I9//GMUO/PMM6vdnIby0EMPuXFv6T/qyxsUzRok9JbTZ+07fujQoSjmDXRKUn
t7exRbuXJlFNuyZYv7eG/CgDeAK/kDnd69BiRp7969RbfBi3uDtZK/pUDW97EWEx+4QgeARFDQASARFHQASAQFHQAS0W1BN7P5ZrbdzD7rEms1s7fMbE3h4+nVbSZQefRtpKaYWS4LJP1e0gtdYjMlLQshzDazmYWvf1P55tXGrFmz3HjWUuhyFXtX8ccff9yNP/FELm9A34gWKPG+ncWbEZO1DH3Hjh1RrLm52c3duHFjFNu/f38Uy3oNeMv8+/Xr5+Z6z+EtxZf8WS579uxxc73n+O6779xcb5l/1iyXWui2soQQlkvafVz4ZkkLC58vlHRLhdsFVB19G6kp9T30YSGEDkkqfPR3oQLyh76N3Kr6wiIza5PUVu3jALVG30ajKfUKfZuZDZekwsftWYkhhLkhhEkhhEklHguoJfo2cqvUK/SlkmZIml34WP/Ny8vw7LPPunHvzubevueV4C0LvvPOO93cq666KorNmzfPzX3ppZfKatdJKKm+7Q1+Sv7S/6x9vL0BwawtBYYMGVJUbtbA4aBBg6JY1pJ5r13bt/u/f72l+/v27XNze7LHude2rMHlSu997ilm2uIrkt6X9A9m1m5mv9bRzj7VzNZImlr4GsgV+jZS0+0Vegjhjoz/uqbCbQFqir6N1LBSFAASQUEHgERQ0AEgEdzgQtLrr7/uxkeMGBHFnnrqqWo35yejR48uOn7FFVe4ud6I/RtvvFFew9CQvBktWUvsvZknWbmnnXZaFMvaFmPs2LFRzJvd0dLS4j7ey922bZub29HREcV27drl5nqzXL799ls3tyfL+Xsyy6UWuEIHgERQ0AEgERR0AEgEBR0AEsGg6AnMmTMnimUNKD755JNRbPDgwW7u1VdfXV7DeuCFF16IYlOnTnVzV6xYUe3moMayBjq9Qb6sPc695fxZg6Le3uVeG7K2DvAGL7MGJA8fPhzFBgwY4OZ6Ojs73bh3vKztB+o5AOrhCh0AEkFBB4BEUNABIBEUdABIBIOiJ+ANmmzatMnNnT59etHPO2PGjCjW2toaxbyB1p7yBqkeffRRN3fu3LlRbMmSXG8HftLLGrTz+nbWnt/eispDhw65ud4AqDegmDVY661sPvXUU93ccePGFd2utWvXFt0Gr70MigIAaoqCDgCJoKADQCIo6ACQiGLuKTrfzLab2WddYo+Z2ddmtqrw7/rqNhOoPPo2UlPMLJcFkn4v6fg15M+EEGq3OXhCFi5cWFTehg0b3HhbW1sUu+6669xcbyR/2rRpbq53F/Vly5a5uQcPHnTjObNACfVtb8ZF1swV7+eXtZf4mjVroljWbBJvNog3cyXLwIEDo9jZZ5/t5nrbD1xwwQVurvdaynp9NdrMlZ7o9go9hLBc0u4atAWoKfo2UlPOe+j3mtnfC3+2nl6xFgH1R99GLpVa0P8gaZykiZI6JP0uK9HM2szsIzP7qMRjAbVE30ZulVTQQwjbQghHQgg/SvpPSZedIHduCGFSCGFSqY0EaoW+jTwraem/mQ0PIRy7Q+svJH12onyUJmvv9a1bt0axyZMnu7ne/tBZy5hvuummKJa1p3sig6KRPPdtbzAva89vbwB8x44dbq538+l169a5uXv37o1i3gBqr15+6Rk+fHgUu+iii9xcbwDV2w5AkiZNin/nvv/++26u933My0BptwXdzF6RdJWkIWbWLum3kq4ys4mSgqSNku6pYhuBqqBvIzXdFvQQwh1OeF4V2gLUFH0bqWGlKAAkgoIOAImgoANAIrjBRQ598MEHUWzx4sVurnczDeSbN+tE8rd5aG5udnNbWlqiWNY2Adu2bYtiWcv59+3bF8UOHz4cxbJmuXizZLybtEj+zJWsXG/7gD59+ri53iywvMxy4QodABJBQQeARFDQASARFHQASASDokDONDU1uXFvkO+0005zc08/Pd5EMmug0hsA9QYvs3J/+OGHKNaTc8jaqqJ///5RrHfv3m6ud27e9gcnOl4ecI
UOAImgoANAIijoAJAICjoAJIKCDgCJSHaWy9133+3Gn3/++Rq3BCidt8w/azaKN6PFW/Iu+TeH8JboS/7NTLyZK1Lxy+azZrl4M1fOP/98N3fUqFFRLGurg/Xr10exnTt3urnMcgEA1B0FHQASQUEHgER0W9DNbKSZvW1mq83sczO7rxBvNbO3zGxN4WO89AxoYPRtpKaYQdFOSQ+EED4xswGSPjaztyT9q6RlIYTZZjZT0kxJv6leU7M99thjUeyRRx5xcy+99NIods89jXkf4IkTJ7rx22+/PYpl7Xvu7ZGNnzR83/Z+ft5e5pJ0xhlnRLHhw4cXnZs1KNrZ2Vl0rtc2b2B36NCh7uOvvfbaKHbDDTe4ud45bNmyxc198803o9iuXbvc3Lzsfe7p9tUeQugIIXxS+PyApNWSRki6WdLCQtpCSbdUq5FANdC3kZoeXb6Z2RhJF0taIWlYCKFDOvrCkOT/ygVygL6NFBQ9D93M+ktaJOn+EML+rNtgOY9rk9RWWvOA6qNvIxVFXaGbWbOOdviXQgjHbl65zcyGF/5/uKTt3mNDCHNDCJNCCPENAIE6o28jJcXMcjFJ8yStDiE83eW/lko6NhI3Q9KSyjcPqB76NlJTzFsuP5P0L5I+NbNVhdjDkmZL+rOZ/VrSV5J+WZ0mds8blc5avnvXXXdFsXPPPbfo560E709671ijR492H+8tee7JcuWsu7vPmTMninl3fE9Iw/TtrLd5vCXyWbNc+vbtG8Wylv6fddZZRR1L8rcJ2L17t5vrnceQIUOi2JQpU9zHX3755UU9XvKX7r/88stu7vLly6NY1kydPOu2oIcQ/ltS1puK11S2OUDt0LeRGiYpA0AiKOgAkAgKOgAkwmq5zNXMqnKwK6+8MootWrTIzfUGeLKWx1drX2TveLU81jPPPOPmPvjgg1VpQy2FEIqbRF5h5fbtrD7Yp0+fKHb66f7WMuPHj49i3iCjJF1yySVRzFtKn9UGb+91yX99eXucDxgwwH28t81Ae3u7mzt//vwo9uKLL7q5mzdvjmJHjhxxcxtVMX2bK3QASAQFHQASQUEHgERQ0AEgERR0AEhE0bstNrJ33nknik2YMMHNbWuLN8e77bbb3NysLQEa0RdffOHGFy9eHMVmzZpV7eagh7Jmm3kzMQ4dOuTmbt8e7yHW0dHh5npL91tbW93cYcOGRbHBgwe7ud6MFm8GlzfrRJLefffdKPbaa6+5uR9++GEU874HUv5mtJSKK3QASAQFHQASQUEHgERQ0AEgEUks/S/XOeec48YnT55cleMVux96T7z33ntufN26dWU9b97kdel/Fm9LgObmZjfXW04/dKh/O9SxY8dGsXHjxrm5Z555ZhTr1cufT+EN2G7YsCGKZQ3if/XVV1Esa+91b1//am2h0QhY+g8AJxEKOgAkgoIOAIko5ibRI83sbTNbbWafm9l9hfhjZva1ma0q/Lu++s0FKoe+jdQUs1K0U9IDIYRPzGyApI/N7K3C/z0TQniqes0Dqoq+jaQUc5PoDkkdhc8PmNlqSSOq3bBaWrt2bY/iSEMe+rY3ayPrbvXerI89e/a4uWvWrIliWTfZKLZdWXEvVsvZdSeTHr2HbmZjJF0saUUhdK+Z/d3M5puZfxsVIAfo20hB0QXdzPpLWiTp/hDCfkl/kDRO0kQdvcr5Xcbj2szsIzP7qALtBSqOvo1UFLWwyMyaJf1V0t9CCE87/z9G0l9DCBd28zz8nYWq6unCopT6trdgLettFC/OWy6NrSILi+xoL5knaXXXDm9mw7uk/ULSZ6U0EqgX+jZS0+0VuplNkfSupE8lHftV+7CkO3T0T9IgaaOkewqDTCd6Ln4to6p6coVO30aeFNO32csFSUltLxfgGPZyAYCTCAUdABJBQQeARFDQASARFHQASAQFHQASQUEHgERQ0AEgERR0AEhEMTe4qKSdkjYVPh9S+Do1nFf9jK7jsY/17Tx8n0qV6r
nl4byK6ts1Xfr//w5s9lEIYVJdDl5FnNfJLeXvU6rnltJ58ZYLACSCgg4AiahnQZ9bx2NXE+d1ckv5+5TquSVzXnV7Dx0AUFm85QIAiah5QTezaWb2pZmtNbOZtT5+JRXuCL/dzD7rEms1s7fMbE3hY+7uGG9mI83sbTNbbWafm9l9hXjuz62aUunb9Ov8ndsxNS3oZtYkaY6kf5Z0vqQ7zOz8WrahwhZImnZcbKakZSGE8ZKWFb7Om05JD4QQ/lHSP0n6t8LPKYVzq4rE+vYC0a9zqdZX6JdJWhtCWB9C+F7SnyTdXOM2VEwIYbmk3ceFb5a0sPD5Qkm31LRRFRBC6AghfFL4/ICk1ZJGKIFzq6Jk+jb9On/ndkytC/oISZu7fN1eiKVk2LEbChc+Dq1ze8piZmMkXSxphRI7twpLvW8n9bNPtV/XuqB7Nzllmk2DMrP+khZJuj+EsL/e7Wlw9O2cSLlf17qgt0sa2eXrsyVtqXEbqm2bmQ2XpMLH7XVuT0nMrFlHO/1LIYTFhXAS51YlqfftJH72qffrWhf0lZLGm9lYM2uR9CtJS2vchmpbKmlG4fMZkpbUsS0lMTOTNE/S6hDC013+K/fnVkWp9+3c/+xPhn5d84VFZna9pP+Q1CRpfgjh32vagAoys1ckXaWju7Vtk/RbSW9I+rOkUZK+kvTLEMLxA0wNzcymSHpX0qeSfiyEH9bR9xtzfW7VlErfpl/n79yOYaUoACSClaIAkAgKOgAkgoIOAImgoANAIijoAJAICjoAJIKCDgCJoKADQCL+D2z5pJLIIoFFAAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x2930ae24898>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Train the VAE by minimising the negative ELBO: reconstruction loss + KL(q(z|x) || N(0, I)).\n",
"# vae(...) returns z_sigma as a standard deviation (it scales the unit-normal noise in forward),\n",
"# so the closed-form KL term below squares it to obtain the variance.\n",
"for epoch in tqdm.tqdm_notebook(range(50)):\n",
"    for images, _ in train_loader:\n",
"        images = images.view(-1, 28*28)\n",
"        images = images.cuda()\n",
"        predicted_images, (z_mu, z_sigma) = vae(images)\n",
"        if PROBA_OUT:\n",
"            # Decoder output is split into [means | log-variances] per pixel;\n",
"            # gaussian_log_prob's third argument is a variance, not a std.\n",
"            predicted_images = predicted_images.view(images.shape[0], 2, -1)\n",
"            mus, variances = predicted_images[:, 0], torch.exp(predicted_images[:, 1])\n",
"            predicted_images = mus\n",
"            reconstruction_loss = -torch.sum(gaussian_log_prob(images, mus, variances))\n",
"        elif SSE_RECON_LOSS:\n",
"            reconstruction_loss = torch.sum((predicted_images - images)**2)\n",
"        else:\n",
"            # Decoder outputs are logits here; the sigmoid is folded into the loss.\n",
"            reconstruction_loss = F.binary_cross_entropy_with_logits(\n",
"                predicted_images, images, reduction=\"sum\")\n",
"        # KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)\n",
"        z_var = z_sigma**2\n",
"        kl_loss = -0.5 * torch.sum(1 + torch.log(z_var) - z_mu**2 - z_var)\n",
"        loss = kl_loss + reconstruction_loss\n",
"        loss.backward()\n",
"        optim.step()\n",
"        vae.zero_grad()\n",
"\n",
"    if epoch % 10 == 0:\n",
"        # Report the losses of the last batch and show one input next to its reconstruction.\n",
"        print(f\"epoch {epoch}; last KL={kl_loss.item()}; last RL={reconstruction_loss.item()}\")\n",
"        images = images.view(images.shape[0], 1, 28, 28)\n",
"        if DO_SIGMOID:\n",
"            predicted_images = torch.sigmoid(predicted_images)\n",
"        predicted_images = predicted_images.view(images.shape[0], 1, 28, 28)\n",
"        plt.subplot(121)\n",
"        plt.imshow(images[0, 0].cpu().numpy(), cmap=\"gray\")\n",
"        plt.subplot(122)\n",
"        plt.imshow(predicted_images[0, 0].cpu().detach().numpy(), cmap=\"gray\")\n",
"        plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Interpolating in latent space"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAC7CAYAAAB1qmWGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAEPNJREFUeJzt3XmsVGWax/HfIyAKgkEJS2gYTCturWJEHGNrnMhi8w92ItoCBg2RdotLNDZbbCPBEEdpjT0xAUHAhbYTukeC2wBBaMlIFEUFri0GGEFR4ihLZIdn/qDIXHnfgrq3zqm65+X7SUhVPfXWOc+596mHc89q7i4AQPGdVO8EAADZoKEDQCJo6ACQCBo6ACSChg4AiaChA0AiaOgAkAgaOgAkoqqGbmbXm9k/zexLMxubVVJAvVHbKCJr7pmiZtZK0heSBkraLOkDSbe4+9rs0gNqj9pGUbWu4rP9JX3p7uslycz+ImmopLJFb2ZcZwC5cnfLYDLUNlqcSmq7mk0uPSRtavR6cykGFB21jUKqZg099r9FsJZiZmMkjaliPkCtUdsopGoa+mZJPRu9/oWkb44e5O7TJE2T+LMUhUFto5Cq2eTygaRzzOwsMztZ0u8kzc8mLaCuqG0UUrPX0N39gJndK+kdSa0kzXT3NZllBtQJtY2iavZhi82aGX+WImcZHeXSZNQ28pb3US4AgBaEhg4AiaChA0AiaOgAkAgaOgAkgoYOAImgoQNAImjoAJAIGjoAJIKGDgCJoKEDQCJo6ACQCBo6ACSChg4AiajmjkVopvPOOy+ITZ8+PYi1adMm+vlBgwYFsR07dlSfGIBCYw0dABJBQweARNDQASARNHQASERVO0XNbKOknZIOSjrg7v2ySCoVZvFbAN53331B7KqrrgpiX331VZOmi+xQ2/+vZ8+eQWzFihXRsfPmzQtiEyZMCGLsxM9HFke5/Ju7f5/BdICWhtpGobDJBQASUW1Dd0n/ZWYrzWxMFgkBLQS1jcKpdpPLVe7+jZl1kbTQzD5392WNB5S+DHwhUDTUNgqnqjV0d/+m9LhV0t8l9Y+Mmebu/U7knUooHmobRdTsNXQzay/pJHffWXo+SNLjmWWWgCuuuCIav+uuu4LY0qVLg9jy5cujnx85cmQQmzNnTnTs3r17g9i+ffuiY3EYtf1zo0ePDmLdunWLjr3nnnuCWOyImJdffrn6xCL69Yv/3zp58uQg9uqrr0bHvvTSS0Hs0KFD1SVWI9Vscukq6e+lQ+haS3rV3d/OJCugvqhtFFKzG7q7r5d0SYa5AC0CtY2i4rBFAEgEDR0AEmHuXruZmdVuZi3AiBEjovHYDqH9+/cHsXLXQ2+KL774IojNnDkzOvbZZ58NYnv27Kk6h1py97pcFyGF2u7fPziQR5L03HPPBbHLL7+84ul+/314su0jjzwSHTtr1qyKpxszZcqUaLzc/GLatm0bxGLfz1qrpLZZQweARNDQASARNHQASAQNHQASQUMHgERkcT10SGrVqlUQGz58eHRs7DTiUaNGBbFrrrmm6ry6dOkSxModCRA7bXrYsGFV54BiuPHGG6PxphzR8vnnnwexk08+OYiNGzcu+vlFixYFsc2bN1c8/xMda+gAkAgaOgAkgoYOAImgoQNAItgpmpGuXbsGsSFDhkTHbtiwIYjNnTu3olhTdezYMYjFLgcgxXeKnXrqqdGxu3fvri4x1FWPHj2CWOy65+V8/fXX0fiVV14ZxGKXF3nllVein7/99tuD2KRJkyrO60THGjoAJIKGDgCJoKEDQCJo6ACQiOM2dDObaWZbzWx1o9gZZrbQzNaVHjvlmyaQPWobqankKJdZkv4sqfFt5cdKWuzuU8xsbOn1H7JPrzjKHdES88Ybb+SYyc/t2LEjiC1YsCA6duzYsUHspptuio6dPXt2dYm1DLOUeG2fdFJ8nW369OlBrFOn+P9dsU
tVPPPMM9Gx27dvryivm2++ORqPXUKjKS677LKqPl90x11Dd/dlkn44KjxU0pFv9GxJN2ScF5A7ahupae429K7uvkWSSo/hFaCAYqK2UVi5n1hkZmMkjcl7PkCtUdtoaZq7hv6dmXWXpNLj1nID3X2au/dz9/DarEDLQ22jsCx2Wm4wyKy3pAXu/qvS63+X9L+Ndhyd4e7Hva12CndGb9euXTS+evXqINa5c+fo2AsvvDCIbdq0qbrEmqBDhw7ReOySBD/++GN07EUXXRTE9uzZU11iGajkzuiNpVTbsR2gjz32WHTsxIkTK57uxx9/HMRaws7HM888M4iVu3Z627ZtK55ubOz+/fsrTywnldR2JYctzpX035LONbPNZjZa0hRJA81snaSBpddAoVDbSM1xt6G7+y1l3rou41yAmqK2kRrOFAWARNDQASARNHQASAQ3uGii9u3bR+NnnXVWECt35Eotj2iJ2blzZzT+1ltvBbGRI0dGx55yyilBrCUc5XIiO//884NYU45mOXjwYDT+1FNPNTunPD388MNBrClHs7z++uvReLmfQxGwhg4AiaChA0AiaOgAkAgaOgAkgp2iTTRgwICKxy5cuDDHTLK3bdu2eqeACvTp0ycar7benn766Wh87ty5VU23Wr169YrGR40aVdV0ly5dGo3HLo1R6XXe6401dABIBA0dABJBQweARNDQASAR7BQ9htatwx/P6NGjo2MPHDgQxF588cXMc8rT2rVrg9j7778fHVvubFPk79xzz43Gu3XrVvE0Yjcqf/zxx5udU57atGkTjce+n00xderUaHzSpElB7Oqrr46OXbVqVVU5ZI01dABIBA0dABJBQweARNDQASARldxTdKaZbTWz1Y1ij5nZ12a2qvRvSL5pAtmjtpGaSnYTz5L0Z0lzjor/yd1b5oWSMxI7Bfi66+K3m9yyZUsQe++99zLPKU+XXHJJEPvpp5+iY4t8zehGZqmAtT1s2LCqpxE7wmPfvn3RsbfeemvV8ztaucsXDBo0KIjFvoeS1Llz50xzOmLOnKPLQdq4cWMu88racdfQ3X2ZpB9qkAtQU9Q2UlPNNvR7zezT0p+tnTLLCKg/ahuF1NyG/rykX0rqK2mLpPhl2iSZ2Rgz+9DMPmzmvIBaorZRWM1q6O7+nbsfdPdDkqZL6n+MsdPcvZ+792tukkCtUNsosmadO2tm3d39yF7A30pafazxaFl69+4djY8YMSKIPfnkkzln07IUoba7du1a9TTmz58fxMrt6O7YsWPV86u3NWvWBLGhQ4dGx27YsCGIuXvmOeXhuA3dzOZKulZSZzPbLOmPkq41s76SXNJGSb/PMUcgF9Q2UnPchu7ut0TCM3LIBagpahup4UxRAEgEDR0AEkFDB4BEcIOLjLz55pv1TqFi48aNi8ZjRwI88cQTeaeDJnr77bej8YEDB1Y8jfbt22eVTrN88skn0fjevXuDWP/+ZY8crdijjz4axNavX1/1dFsa1tABIBE0dABIBA0dABJBQweARLBT9Bj69u1b8dhdu3blmEnzDR48OIiVu572DTfcEMQSue55UmbOnBmN9+rVK4iVu5Z4Xl544YUgtmfPniC2bt266OdjO2u//fbbiuf/7rvvRuNLliypeBpFxho6ACSChg4AiaChA0AiaOgAkAgaOgAkgqNcjqGhoaHeKVSs3E0Ipk2bFsTKLdeyZcsyzQn52L59ezT+4IMP1jiT7FV784533nknGt+2bVtV0y0K1tABIBE0dABIBA0dABJx3IZuZj3NbImZNZjZGjO7vxQ/w8wWmtm60mOn/NMFskNtIzWV7BQ9IOkhd//IzDpIWmlmCyXdJmmxu08xs7GSxkr6Q36p1l7s2szldOvWLcdMfq5fv35B7Pnnn4+OPf3004PY8OHDM8+poE7Y2m6pbrvttiBmZtGx7p5zNsVz3DV0d9/i7h+Vnu+U1CCph6ShkmaXhs2WFF4IBGjBqG2kpknb0M
2st6RLJa2Q1NXdt0iHvxiSumSdHFAr1DZSUPFx6GZ2mqR5kh5w9x3l/gyKfG6MpDHNSw/IH7WNVFS0hm5mbXS44F9x97+Vwt+ZWffS+90lbY191t2nuXs/dw83/AJ1Rm0jJZUc5WKSZkhqcPepjd6aL2lU6fkoSa9nnx6QH2obqalkk8tVkm6V9JmZrSrFxkuaIumvZjZa0leS4ndNOEEMGDAgiF1wwQXRsWvXrg1iZ599dhCbMGFC9PMjRowIYrt3746OveOOO4LY8uXLo2NPQNR2ATTlaJZ27dpF43feeWcQK3cDm7vvvjuIHTp0qOIc6um4Dd3d35NUbqPiddmmA9QOtY3UcKYoACSChg4AiaChA0AiuB76MezatSuILVq0KDo2tlP0008/jY49cOBAEGvdOvxVtGrVKvr5FStWBLExY+KHQ5fLAUjRxIkTo/GlS5cGsYEDB0bHFmUHaAxr6ACQCBo6ACSChg4AiaChA0AiaOgAkAiOcjmG2A0uBg8eHB07ZMiQIDZ58uTo2IsvvjiILV68OIjNmDEj+vnXXnstiBV5zzxwxKZNm6r6/L59+6Lx2Hfp4MGDVc2rJWINHQASQUMHgETQ0AEgETR0AEiE1fLO2WbGbbqRK3ev7P5xGaO2s9GhQ4cgtnLlyujY2D0EpkyZEh07fvz46hJrASqpbdbQASARNHQASAQNHQASUclNonua2RIzazCzNWZ2fyn+mJl9bWarSv/CM2uAFozaRmoqOVP0gKSH3P0jM+sgaaWZLSy99yd3fyq/9IBcUdtISiU3id4iaUvp+U4za5DUI+/EgLxR2y3Pzp07g1ifPn3qkEkxNWkbupn1lnSppCO3zLnXzD41s5lm1inj3ICaobaRgoobupmdJmmepAfcfYek5yX9UlJfHV7LebrM58aY2Ydm9mEG+QKZo7aRiopOLDKzNpIWSHrH3adG3u8taYG7/+o40+HkC+SqqScWUdsoikxOLDIzkzRDUkPjgjez7o2G/VbS6uYkCdQLtY3UHHcN3cx+Lekfkj6TdOSi2+Ml3aLDf5K6pI2Sfl/ayXSsabEWg1w1ZQ2d2kaRVFLbXMsFSeFaLkgV13IBgBMIDR0AEkFDB4BE0NABIBE0dABIBA0dABJBQweARNDQASARNHQASEQlN7jI0veS/qf0vHPpdWpYrvr5lzrO+0htF+Hn1FypLlsRlqui2q7pqf8/m7HZh+7ery4zzxHLdWJL+eeU6rKltFxscgGARNDQASAR9Wzo0+o47zyxXCe2lH9OqS5bMstVt23oAIBssckFABJR84ZuZteb2T/N7EszG1vr+WepdEf4rWa2ulHsDDNbaGbrSo+Fu2O8mfU0syVm1mBma8zs/lK88MuWp1Rqm7ou3rIdUdOGbmatJP2HpN9IukDSLWZ2QS1zyNgsSdcfFRsrabG7nyNpcel10RyQ9JC7ny/pXyXdU/o9pbBsuUistmeJui6kWq+h95f0pbuvd/d9kv4iaWiNc8iMuy+T9MNR4aGSZpeez5Z0Q02TyoC7b3H3j0rPd0pqkNRDCSxbjpKpbeq6eMt2RK0beg9Jmxq93lyKpaTrkRsKlx671DmfqphZb0mXSlqhxJYtY6nXdlK/+1TrutYNPXaTUw6zaaHM7DRJ8yQ94O476p1PC0dtF0TKdV3rhr5ZUs9Gr38h6Zsa55C378ysuySVHrfWOZ9mMbM2Olz0r7j730rhJJYtJ6nXdhK/+9TrutYN/QNJ55jZWWZ2sqTfSZpf4xzyNl/SqNLzUZJer2MuzWJmJmmGpAZ3n9rorcIvW45Sr+3C/+5PhLqu+YlFZjZE0jOSWkma6e6Ta5pAhsxsrqRrdfhqbd9J+qOk/5T0V0m9JH0laZi7H72DqUUzs19L+oekzyQdKoXH6/D2xkIvW55SqW3qunjLdgRnigJAIjhTFAASQUMHgETQ0AEgETR0AEgEDR0AEkFDB4BE0NABIBE0dABIxP8BEUk23tWL+n
AAAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x2933a079080>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x2930aefab38>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Take two images from the last training batch (left over from the training loop),\n",
"# encode both, and decode evenly spaced points on the line between their latent means.\n",
"sample_images = images.view(-1, 28 * 28)\n",
"image_a = sample_images[0].unsqueeze(0)\n",
"image_b = sample_images[1].unsqueeze(0)\n",
"plt.subplot(121)\n",
"plt.imshow(image_a.view(28, 28).cpu().numpy(), cmap=\"gray\")\n",
"plt.subplot(122)\n",
"plt.imshow(image_b.view(28, 28).cpu().numpy(), cmap=\"gray\")\n",
"plt.show()\n",
"\n",
"# Inference only: no autograd graph is needed for encoding or decoding.\n",
"with torch.no_grad():\n",
"    # vae(...) returns (reconstruction, (z_mu, z_sigma)); interpolate between the means.\n",
"    z_a = vae(image_a)[1][0]\n",
"    z_b = vae(image_b)[1][0]\n",
"\n",
"    # linearly interpolate from image_a to image_b (inclusive)\n",
"    n_interps = 5\n",
"    ts = torch.linspace(0, 1, n_interps).view(n_interps, 1).cuda()\n",
"    interps_z = z_a + (z_b - z_a) * ts\n",
"    interps_images = vae.decoder(interps_z)\n",
"if DO_SIGMOID:\n",
"    interps_images = torch.sigmoid(interps_images)\n",
"# With PROBA_OUT the decoder emits two planes per image (mean, log-variance); only plane 0 is shown.\n",
"interps_images = interps_images.view(n_interps, bool(PROBA_OUT)+1, 28, 28)\n",
"_, axs = plt.subplots(1, n_interps)\n",
"for ax, image in zip(axs, interps_images):\n",
"    ax.imshow(image[0].cpu().numpy(), cmap=\"gray\")\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment