Skip to content

Instantly share code, notes, and snippets.

@dominicrufa
Created August 23, 2021 21:31
Show Gist options
  • Save dominicrufa/8565c2428cad35fe1738c8f8139ed506 to your computer and use it in GitHub Desktop.
Save dominicrufa/8565c2428cad35fe1738c8f8139ed506 to your computer and use it in GitHub Desktop.
realnvp integrator in n dims without equivariance
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "24760d0f",
"metadata": {},
"source": [
"# Toy phase-space normalizing flow\n",
"I'm going to write an augmented normalizing flow (NF) on Gaussians in low dimensions, then scale it to higher dimensions and see how it performs."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "000bc1b3",
"metadata": {},
"outputs": [],
"source": [
"from typing import Sequence, Callable, Dict, Tuple, Optional, NamedTuple, Any\n",
"import jax\n",
"import flax.linen as nn\n",
"import jax.numpy as jnp\n",
"from functools import partial\n",
"from jax import lax, ops, vmap, jit, grad, random, value_and_grad\n",
"from jax.scipy.special import logsumexp\n",
"from jraph._src.utils import ArrayTree\n",
"from jax.experimental import optimizers\n",
"import jraph\n",
"from jax.tree_util import tree_map\n",
"\n",
"from jax.config import config\n",
"config.update(\"jax_enable_x64\", True)\n",
"\n",
"\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "52be46a8",
"metadata": {},
"outputs": [],
"source": [
"class TanhMLP(nn.Module):\n",
"    \"\"\"Stack of float64 `nn.Dense` layers with `tanh` between consecutive layers.\n",
"\n",
"    `features` gives the output width of each layer in order; no activation\n",
"    follows the final layer.\n",
"    \"\"\"\n",
"    features: Sequence[int]\n",
"\n",
"    @nn.compact\n",
"    def __call__(self, inputs):\n",
"        last_layer = len(self.features) - 1\n",
"        hidden = inputs\n",
"        for depth, width in enumerate(self.features):\n",
"            hidden = nn.Dense(width, dtype=jnp.float64)(hidden)\n",
"            if depth < last_layer:\n",
"                hidden = nn.tanh(hidden)\n",
"        return hidden\n",
"\n",
"class ReluMLP(nn.Module):\n",
"    \"\"\"Stack of float64 `nn.Dense` layers with `relu` between consecutive layers.\n",
"\n",
"    `features` gives the output width of each layer in order; no activation\n",
"    follows the final layer.\n",
"    \"\"\"\n",
"    features: Sequence[int]\n",
"\n",
"    @nn.compact\n",
"    def __call__(self, inputs):\n",
"        last_layer = len(self.features) - 1\n",
"        hidden = inputs\n",
"        for depth, width in enumerate(self.features):\n",
"            hidden = nn.Dense(width, dtype=jnp.float64)(hidden)\n",
"            if depth < last_layer:\n",
"                hidden = nn.relu(hidden)\n",
"        return hidden\n",
"\n",
"class SwishMLP(nn.Module):\n",
"    \"\"\"Stack of float64 `nn.Dense` layers with `swish` between consecutive layers.\n",
"\n",
"    `features` gives the output width of each layer in order; no activation\n",
"    follows the final layer.\n",
"    \"\"\"\n",
"    features: Sequence[int]\n",
"\n",
"    @nn.compact\n",
"    def __call__(self, inputs):\n",
"        last_layer = len(self.features) - 1\n",
"        hidden = inputs\n",
"        for depth, width in enumerate(self.features):\n",
"            hidden = nn.Dense(width, dtype=jnp.float64)(hidden)\n",
"            if depth < last_layer:\n",
"                hidden = nn.swish(hidden)\n",
"        return hidden\n",
"\n",
" \n",
"\n",
" \n",
"def make_mlp(features: Dict[str, Any],  # keyword arguments for `nn_module`, e.g. {'features': [8, 8, 2]}\n",
"             nn_module: Callable[..., nn.Module] = SwishMLP,  # flax module class to instantiate\n",
"             ) -> Callable[[ArrayTree], ArrayTree]:\n",
"    \"\"\"Return a function that builds `nn_module(**features)` and applies it to its inputs.\n",
"\n",
"    A new module instance is constructed on every call of the returned function,\n",
"    so it is meant to be called from inside an `@nn.compact` method (as it is\n",
"    everywhere in this notebook); each call creates a new submodule.\n",
"\n",
"    Args:\n",
"      features: keyword arguments forwarded to the module constructor.\n",
"      nn_module: the flax module class; defaults to `SwishMLP`.\n",
"\n",
"    Returns:\n",
"      `update_fn`, wrapped in `jraph.concatenated_args` so that its array\n",
"      arguments are concatenated before being passed to the module.\n",
"    \"\"\"\n",
"    @jraph.concatenated_args\n",
"    def update_fn(inputs):\n",
"        return nn_module(**features)(inputs)\n",
"    return update_fn\n",
"\n",
"def make_precision64(params):\n",
"    \"\"\"Cast every leaf of the `params` pytree to a float64 `jnp` array.\"\"\"\n",
"    to_float64 = partial(jnp.array, dtype=jnp.float64)\n",
"    return tree_map(to_float64, params)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "bd0b21ba",
"metadata": {},
"outputs": [],
"source": [
"class NVPForward(nn.Module):\n",
"    \"\"\"One forward step of a RealNVP-style flow on a phase-space vector [x, v].\n",
"\n",
"    The last axis of `inputs` is split in half into positions `x` and velocities\n",
"    `v`, and three updates are applied in sequence:\n",
"\n",
"      1. v  -> v' = v  * exp(s1(x))  + t1(x)\n",
"      2. x  -> x' = x  + v'                      (a shear: adds 0 to log|det J|)\n",
"      3. v' -> v''= v' * exp(s2(x')) + t2(x')\n",
"\n",
"    where s*/t* are `SwishMLP`s built via `make_mlp` from `log_scale_features`\n",
"    and `translate_features` (their outputs must broadcast against x / v).\n",
"    Returns `(concat([x', v''], axis=-1), log|det J|)`, where log|det J| is the\n",
"    sum of the log-scales over the last axis.\n",
"    \"\"\"\n",
"    log_scale_features : Sequence[int]\n",
"    translate_features : Sequence[int]\n",
"\n",
"    @nn.compact\n",
"    def __call__(self, inputs):\n",
"        # assumes an even-sized last axis: first half positions, second half velocities\n",
"        d = inputs.shape[-1] // 2\n",
"        x, v = inputs[..., :d], inputs[..., d:]\n",
"\n",
"        # update velocities from positions\n",
"        vlog_scale = make_mlp({'features': self.log_scale_features})(x)\n",
"        vtranslate = make_mlp({'features': self.translate_features})(x)\n",
"        updated_vs = v * jnp.exp(vlog_scale) + vtranslate\n",
"        vlogdetJ = jnp.sum(vlog_scale, axis=-1)\n",
"\n",
"        # update positions from updated velocities; a pure shift, so no log|det J| term\n",
"        updated_xs = x + updated_vs\n",
"\n",
"        # second velocity update, conditioned on the updated positions\n",
"        vlog_scale2 = make_mlp({'features': self.log_scale_features})(updated_xs)\n",
"        vtranslate2 = make_mlp({'features': self.translate_features})(updated_xs)\n",
"        updated_vs2 = updated_vs * jnp.exp(vlog_scale2) + vtranslate2\n",
"        vlogdetJ2 = jnp.sum(vlog_scale2, axis=-1)\n",
"\n",
"        return jnp.concatenate([updated_xs, updated_vs2], axis=-1), vlogdetJ + vlogdetJ2"
]
},
{
"cell_type": "markdown",
"id": "c9b24a51",
"metadata": {},
"source": [
"TODO: write a test for the flow layer above."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f9e80442",
"metadata": {},
"outputs": [],
"source": [
"class SingleAlternativeRNVP(nn.Module):\n",
"    \"\"\"Two `NVPForward` steps applied back to back, each with its own parameters.\n",
"\n",
"    Returns the output of the second step and the sum of both steps' log|det J|.\n",
"    \"\"\"\n",
"    log_scale_features : Sequence[int]\n",
"    translate_features : Sequence[int]\n",
"\n",
"    @nn.compact\n",
"    def __call__(self, inputs):\n",
"        step_kwargs = {'log_scale_features': self.log_scale_features,\n",
"                       'translate_features': self.translate_features}\n",
"        first_out, first_logdetJ = make_mlp(features=step_kwargs, nn_module=NVPForward)(inputs)\n",
"        second_out, second_logdetJ = make_mlp(features=step_kwargs, nn_module=NVPForward)(first_out)\n",
"        return second_out, first_logdetJ + second_logdetJ"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e8c5fecc",
"metadata": {},
"outputs": [],
"source": [
"class AlternativeRNVP(nn.Module):\n",
"    \"\"\"`num_layers` `NVPForward` steps composed in sequence, each with its own parameters.\n",
"\n",
"    Returns the final transformed array and the accumulated log|det J|, one value\n",
"    per sample (shape `inputs.shape[:-1]`).\n",
"    \"\"\"\n",
"    log_scale_features : Sequence[int]\n",
"    translate_features : Sequence[int]\n",
"    num_layers : int\n",
"\n",
"    @nn.compact\n",
"    def __call__(self, inputs):\n",
"        step_kwargs = {'log_scale_features': self.log_scale_features,\n",
"                       'translate_features': self.translate_features}\n",
"        outputs = inputs\n",
"        # one log|det J| per sample: every axis except the feature (last) axis\n",
"        logdetJs = jnp.zeros(inputs.shape[:-1])\n",
"        for _ in range(self.num_layers):\n",
"            outputs, step_logdetJs = make_mlp(features=step_kwargs, nn_module=NVPForward)(outputs)\n",
"            logdetJs = logdetJs + step_logdetJs\n",
"        return outputs, logdetJs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cd500362",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"id": "5916f1db",
"metadata": {},
"outputs": [],
"source": [
"def unnormalized_Normal_logp(x, mu, cov): #tested\n",
"    \"\"\"\n",
"    Log-density, up to an additive constant, of a Gaussian with diagonal covariance.\n",
"\n",
"    arguments\n",
"        x : jnp.array(Dx)\n",
"            evaluation point\n",
"        mu : jnp.array(Dx)\n",
"            mean vector\n",
"        cov : jnp.array(Dx)\n",
"            diagonal of the covariance matrix (per-dimension variances)\n",
"    returns\n",
"        out : float\n",
"            -0.5 * sum_i (x_i - mu_i)**2 / cov_i\n",
"    \"\"\"\n",
"    residual = x - mu\n",
"    scaled_residual = residual / cov\n",
"    return -0.5 * jnp.dot(scaled_residual, residual)\n",
"\n",
"def unnormalized_gmm_logp(x, mus, covs, lws):\n",
"    \"\"\"\n",
"    Unnormalized log-density of a diagonal-covariance Gaussian mixture at `x`.\n",
"\n",
"    arguments\n",
"        x : jnp.array(Dx)\n",
"            evaluation point\n",
"        mus : jnp.array(K, Dx)\n",
"            component means, one row per component\n",
"        covs : jnp.array(K, Dx)\n",
"            per-dimension variances, one row per component\n",
"        lws : jnp.array(K)\n",
"            log mixture weights\n",
"    returns\n",
"        out : float\n",
"            log(sum_k exp(lws[k] + unnormalized_Normal_logp(x, mus[k], covs[k])))\n",
"    \"\"\"\n",
"    component_logps = vmap(unnormalized_Normal_logp, in_axes=(None, 0, 0))(x, mus, covs)\n",
"    # logsumexp instead of log(sum(exp(...))): the naive form underflows to -inf\n",
"    # when x is far from every component mean\n",
"    return logsumexp(lws + component_logps)\n",
"\n",
"#samplers\n",
"def sample_normal(seed, N, mu, cov):\n",
"    \"\"\"\n",
"    Draw N samples from a Gaussian with mean `mu` and diagonal covariance `cov`.\n",
"\n",
"    Returns an array of shape (N, len(mu)).\n",
"    \"\"\"\n",
"    noise = random.normal(seed, shape=(N, len(mu)))\n",
"    return noise * jnp.sqrt(cov) + mu\n",
"\n",
"def sample_gmm(seed, mus, covs, lws):\n",
"    \"\"\"\n",
"    Draw one sample from a diagonal-covariance Gaussian mixture.\n",
"\n",
"    A component index is chosen with probabilities softmax(lws), then a point is\n",
"    drawn from that component's Gaussian. Returns an array of shape (Dx,);\n",
"    `vsample_gmm` vmaps this over seeds to draw many samples.\n",
"    \"\"\"\n",
"    _, dim = mus.shape\n",
"    seed1, seed2 = random.split(seed)\n",
"    mixture_idx = random.choice(seed1, len(lws), p=jnp.exp(lws - logsumexp(lws)))\n",
"    return random.normal(seed2, shape=(dim,)) * jnp.sqrt(covs[mixture_idx]) + mus[mixture_idx]\n",
"\n",
"def free_energy(works):\n",
"    \"\"\"\n",
"    Exponential-average (Jarzynski) free-energy estimate from an array of works:\n",
"    -log(mean(exp(-works))), evaluated with the minimum work factored out.\n",
"    \"\"\"\n",
"    num_works = len(works)\n",
"    min_work = jnp.min(works)\n",
"    return min_work - logsumexp(-works + min_work) + jnp.log(num_works)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "cda0f311",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
}
],
"source": [
"og_seed = random.PRNGKey(1)\n",
"# 1000 keys split from og_seed; `vsample_gmm` draws one GMM sample per key\n",
"mut_seed = random.split(og_seed, num=1000)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9f9dbcf2",
"metadata": {},
"outputs": [],
"source": [
"# vectorize sample_gmm over the seed argument only; the mixture parameters are shared\n",
"vsample_gmm = vmap(sample_gmm, in_axes=(0, None, None, None))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "8296d0ec",
"metadata": {},
"outputs": [],
"source": [
"# one sample per key from an equal-weight 3-component 2-D GMM (means (0,0), (2,2), (-2,2); variances 0.1)\n",
"data = vsample_gmm(mut_seed, jnp.array([[0., 0.], [2., 2.], [-2., 2]]), jnp.array([[.1, .1],[0.1, .1], [0.1, 0.1]]), jnp.array([0., 0., 0.]))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "df8ee349",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([1.67108706, 1.82428184], dtype=float64)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# inspect the first GMM sample\n",
"data[0]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "af126246",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(-0.69530298, dtype=float64)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# sanity check: mixture log-density (same GMM as `data`) at the first sample\n",
"unnormalized_gmm_logp(data[0], \n",
" jnp.array([[0., 0.], [2., 2.], [-2., 2]]), \n",
" jnp.array([[.1, .1],[0.1, .1], [0.1, 0.1]]), \n",
" jnp.array([0., 0., 0.]))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "f312f9a4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.collections.PathCollection at 0x7f1854072910>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAApoElEQVR4nO2df4wc53nfv8/uDalZGuWe4itirnkm4bpkzNLkmQeJAftHqQaibEXKhbTEyFJboEWFAGlRCsShp0gQqUYuDzg4ZlH4HwE2gkCsQsmUtpJphE5BFm7pUvExxzPNimwsy6K0VGvG5MoWb0nu7b39Y2+Ws7Pzzo/d2Z2Z3e8HCGIe92bfoWa+88zzfp/nEaUUCCGEpJdM3AsghBDSGRRyQghJORRyQghJORRyQghJORRyQghJOUNxfOknP/lJtW7duji+mhBCUsvZs2f/Tik14vx5LEK+bt06zM7OxvHVhBCSWkTkPbefM7VCCCEph0JOCCEph0JOCCEph0JOCCEph0JOCCEpJxbXCuk9xbkSZk5cwpVyBWvyJiZ3bcDEWCHuZRFCIoBCPgAU50p4+rXzqFRrAIBSuYKnXzsPABRzQvoAplYGgJkTlxoiblGp1jBz4lJMKyKERAmFfAC4Uq6E+jkhJF1QyAeANXkz1M8JIemCQj4ATO7aANPINv3MNLKY3LUhphURQqKEm50DgLWhSdcKIf0JhXxAmBgrULgJ6VOYWiGEkJRDISeEkJRDISeEkJRDISeEkJRDISeEkJRDISeEkJRDISeEkJRDH3nKYXtaQgiFPMWwPS0hBGBqJdWwPS0hBIhAyEXkLhH5axGZF5ELIvJ8FAsj/rA9LSEEiCYivwXgPqXUFgBbATwgItsjOC7xge1pCSFABEKu6ny8/Edj+f9Up8cl/rA9LSEEiGizU0SyAM4C+AcAvqmUesvlM08CeBIARkdHo/jagYftaQlpZRCdXKJUdMGziOQBvA7g3yqlfqL73Pj4uJqdnY3sewkhBGh1cgH1t9RDuzf3hZiLyFml1Ljz55G6VpRSZQD/HcADUR633yjOlbBj+iTWTx3HjumTKM6V4l4SIX3BoDq5onCtjCxH4hARE8DvALjY6XH7FStiKJUrULjj/aaYE9I5g+rkiiIi/xSAUyLyYwA/AvBXSqnvRnDcvmRQIwZCesGgOrk63uxUSv0YwFgEaxkIgkYMg7hhQ0inTO7a4Joj73cnF0v0e8yavImSi5hnRFCcK2FirMDSe0LaJIyTSxcspTGIitS1EpRBcK14XSTOiMHC2l2fOXHJVewLeROnp+7r2VoJSSqdXrM6d8uebQUcO1tKrOtF51qhkHcBPwtUca6E/a/Mo+byb1/Im7iyvBHqhgCRim2/27VIcmlXjKO4ZndMn3QNlrIi2vuyG0FUWHpiPyR1/DY0J8YKWNI8QK2LWkfUThduvpI46MS9pbtmD75xIbCtV7dX5SbiXp9PChTyLhBkQ9Nrd92t9N5JVGI7qHYtEi/tBBDFuRK2Pv9910gaAMqVauAHQ1gXS9JdLxTyLhDEAuXVJ2VirIBDuzejkDchHt9TKlc6jsoH1a5F4iVsAFGcK2Hy1XmUK9XA31Gp1rD/lXnXCD1IsGSRBtcLhbwLBGlm5RTrQt7Enm0FzJy4hPVTx/H8mxdw49YigHreTkenKRY23iJxEDaAmDlxCdWl8Pt5NaVcI3Tr/tPdW1mRxn2Zhv0i2g87QLdZE9QCZf+scwPn+sKdyEOXtwPuvI62e6Gx8RaJg7B+b79UX87IYKG65PkZe+rGut51d1ZNKRzeuzU19wFdK20StdtDt4seBAHw7vSDbf0uIXERxrXidX/kTQM3bi+iWgumZaaRdbX/un0uadG4zrXCiLxNvDZr2vkP38nmYkYEzxbP49TFq4yqSWqwv5H6MblrAyZfnW9JrxhZgQgCiziAQCJufa6Tt91ewhx5m0Tt9giyuTicM1w3aGpK4a
Uzl9mIi/QtE2MFzDyyBXnTaPxsOGdg5itbUF4IvgEallK5kopOpYzI20RXat+O26M4V2psbOowjSwOPLQJALTFRHb8oglWc5K0oYvgdZXQUSBA49hJbpVBIW+TdprzuIknANeS/VUrsjCyGXxUqbYI7VNHzwVao5eVi71cSJqx30urbVF6lAhaZ1YmNd1CIW+TsG4PnXiuHMq45uwWqjXgds31uLq3ASf2Rlx2os7vE9JtnMJt39wM4y33IyuCJaU877EkFstRyDsgzGaNTjx1Gy9W5sQtWnZ7G3CjppRrpM1qTtJt7MKbzxlQCk1vl0DwIOjZ4nkcOXO5ER1HKdxOakrh58sOMJ1TJonFcrQfLhNVzth5nJ0bR/Dd+Q87vvhMI4O7V61sOu6pi1cDRebOhj+6CzQpjYFI+vCKmJ0YGQEcThOd1a84V8JTR89p/d7dwGpMt3PjSOI6IbJplgdRjV+zyojtx3npzOVIIohKdanpuEfOXMbOjSMoBIgOnJE2qzlJlDjvn3Kl6mkHrC6plr/X9VmZOXGppyIO3GlMd+xsCV8cXd2o/syKYM+25iK+pDhamFpBuJyxV+R+8I0LbZURt4MCcOTMZTy+fbQlanBivQo6o6a7jAzKC62bqYSEwe3+aQfL6me9bXpVXvaCSrWGH75zrbGGmlI4draE8c/cDQCJMgxQyBFu/JrzP95TR89h9r1reGFic1dzd24oAKcuXm0aRuHcabfsU2P/4fv4+OZi40FTrlRhGll8I0VlyCSZRLm3Yr3FBmE4ZzS1sugGbq6Vg29cwK9vLrZYgN1aAPQqSGJqBcEb+LhFHgrAS2cuY9Nzf9mt5XliPUwA4PDerfjG3q2NdItd1K8vVFveFth3nERB2M0/IyMwsl59Pf15Yvso5p67Hz+ffhCrVgTrYhgV5UpVW8dhRea9Ls6jkCN4ztgr8rhxu/NXy3axXzAAcHrqPhTyZqDX0m4VUpDBYXLXBs92y8M5A3nTaHQTnHlkC2a+siXQ/o6Ol85cxmef/h6eLZ6P9d5zkhWJZVALhRzuLWXddqbjsh1lRbDjs3f7fs5+wQQVaK8WuYQEYWKsgMe3j7aIuWlkcXjvVhx4aBNWrbyTxZ1971oj9ZDp4PKzWlMEoReXuWlkY5swxBz5MkE84ZO7NvTcCgUAS0rhyL/+bfzDZ76H2z7NgawLRuSOF90Lv1J/QoLwwsRmjH/mbt/KZWcOvFeXXyffY09RZgRw8zNkRTwHp3c7CKSQh2BirIDZ9641FSf0AgVg3dTxQJ9dkzdRnCsFvnA7eb0lxI5bMLRj+mQkjpY4sd9KbiLu9JaHbd0RBRRyhCsGckYeq00Dv761iFqPbIdeWBdM0HycaWSxc+MIdkyfdD13NtYindKv1cKZ5Tde530R16CWga7sLM6V8PybF1osTNYTFgj2H0R3nF5jTTRZP3Vc+8YwnDNQXqhitWng9mKtZaqK/dyjHJxBBpNOBqYknTgmCOkqOwdWyN0m/NjJmwZuLS41/b2VKyvYSuR1key+gB0KoyQrgsfuXast3c+bBs4duN/33K10C8v4STs4+6zY6xfaZeVQBrcWm4MOqVf6u6Y7eoHfvdCNN1qW6Dvwq0YrV6qunnHgzoaNzis6MVbAcK47rTW9sHbxcysyrnYwkTsXl9e5XylX2FiLtIWzXP/6QhUQNNkPn9g+2nCIWdZEP5wiDtSP+dV7R5vcZjmjd5LmdS9E1fYjKAMr5FELktMr+uAXPhXp8cPwt7+44Zpaub5QbVxcXqzJm6GnnBMCuAdI1ZrCqpVDeHf6wUYE+38/ugkF4FeVRfzulk95+tB1XF+o4tjZEnZuHGm0nb3pIvhBMY1MqM1/r3vBq+1HNxhYIe+GIFkPh+JcCUf/+v3Ijx8FlWrN0ztubZiysRZpB783uWeL5/HSmcsN26v1FtludqRSrTXejoHO0iyV6hKu37gV6LN+90Kv32gHVsjdhKpTrEEOvWye1Q
41pVzPfThnNDYzrSIpe4po5dDAXi4kIH5vci+/lcwAx8K5+e+GrmDQTq/faAf2znQTqrxpdJTbrimFfUfP9bx5VlisC9F57gce2tRycd60XdjlSpVDnYknfm9yaS9Ay5sGTk/dF6h4sJdvtAMr5BZOofr45mJLQx/TyDY2aIBkl7UL4PumsXPjCIDWc3/q6Dk8Wzzf+Fmv83wk/fi1uwhy6xTyJswebloGxcgIDj68KdBng7b9iIqBKwiyW4IyIi0RQnVJIW8aWLVyqMU25GfbSwIKwKHdm7H/lXlt9HPq4lWcunhV28nx+I8/RHmhqs1b0rlCnASx2hXnSq4DjZ0k0XcuAPbeszaUEIcZBdkpAyXkTiHWCV25UsW5A/c3/V7aChseu3ettqGQnxD7FTbRuULs6Pr07zt6rqnmIk33jxOr978dt7GOutqSbtPXQu78h164vRgomranTtIQhTuZ/M68Z9izevmNo50bi84V4kTXpx8INygi6dgDILeHl/08ez0xKHmJqIhwM+QHLaG3R+oH37iQKhEH6r5dL9eMiH8P6ZbfQffzfCSdDEqqzf4mGmS8nXM/qZszPjuOyEVkLYA/B/CbAJYAvKiU+k+dHrdTOpkjaG1qFudKiXegtEN5oRqqkyPL8okXVjFOv2PNFJ3ctSHww6tUrmD91HGsNg3cuL3YGDoddcQeRUS+CGC/Uuq3AGwH8Eci8vkIjtsR7UYJ9tRBv7ozrMjihYnNjdFwgrq1ys2xw1QK8aIbNRlxEMR6bAlwPoRNWaG+71Z1zBKI0gHWcUSulPoQwIfL//vXIvI2gAKA/93psTtBFyU4HSleGxT9+spoF2bnzjpb15Kw2Fu3ug0ATwMrhzKBe/hXqrXI0q1RaUykm50isg7AGIC3XP7uSQBPAsDo6GiUX+vK5K4Nrm1YDz7cWvSiox9fGYdzhm+/cQo3CYv9unFeV+t+w8QP37mWWHHPZgTVxSXXxlzdJioHWGRCLiKfAHAMwD6l1K+cf6+UehHAi0C9jW1U36sjigbvcY126yYHHqoXNLjtuvdyl52kG7tYrzYNVGtLjSHIedNoCph2TJ9M7D2UFcGKrKASQ0uNKNOWkQi5iBioi/gRpdRrURwzCqKILnMrsoma0t0J9mjcq2qTQk68cAYBTkNAuVLF5KvzAOr3YFJTlKaRxZ5thZ7ZI42M4BN3DaG8UI08bRmFa0UAfAvA20qpP+18Sb1Fl15Io3/c4nN/fxX+9hc3Wn5ub63LfuOkXYI4wqpLqrGR51ZB7caqHgdNGUHPupQWurzfFEVEvgPAPwNwXkTOLf/sj5VS34vg2F3FK73QiX0xLuwTgtz47vyHjY1dEffJ4qzaJH6Esd49/dr5QCL+xPZRvDCxudHmthf06qHRC/tuFK6V/wm01Rc+drzSC2mMTN859GUAwPqp465/X65UG6/BbvdWNiO4cWsR66eO07FCtAQ1AWRFglVSZwTjn7kbxbkSjiS8ClQQ/A0D6J19t69L9C106ROv9ELaHCtWt7jiXCnUhWantqQaQs/NT6LDzRHmJIPgLWtrSwrPvH4eC7drid0UtVAAlgKeV1akZ5XQfVuib+E1O8+r+bvV6jUqul0wcbO6hGeL5wO/ygaBLWuJG84WrW4tZ7NZCTSL0+JGCkQcqIvzXQFb7C4p1bMgqO+F3Ct94tX8XZdnBsL3Iy/kTXxxdHWo3wmL1YI26rx+GlNMpPtMjBVweuo+vDv9IO5etbLl76s1hWptKZ05Vw9qSqESYIoQ0Nv9pr4Xcq/0iVfzd6+0SpiI13ow/PBn18IuPRFw85P4obvH0hJld4Net7ZIdY48SDm5LtdtCZSb1zyqrmR2y9G+o+cC/162zRx3J+SMDBSkpRKWfVaIH2nbT9LR6X1XyJuxtbZIbUTulfu2E3Z2XnGuhP2vzEeyxp0bRzAxVmganxaEJaVweO/WnjUiEgD/cfcXejqaivQP/dI0qxMRF9T/Hd6dfj
DQTM+oSW1EHrQyMUypvvVwiCoatvywYX2xa/JmY30H37jQ1Va6AuDx7aON76Nwk7A477F8zvAcFdiPKCDWqujUCLkzjaJ7lXPL1wUt1e9GEVA7xQ3rfqM57VOcK3Wl50tWBF9/dAvFm3SM/Vp9+rXzAyXiFnEaA1Ih5G4VmLpWmX6bc1559aQ4NE6/cw2bnvtLGNkMPqrU+zJ048bopT2KDAZprIiOijV5M7Y20KkQct1MQKeY+23O+XX884r0h3NG4FFxUVAvH76zznbw27yhI4VETVKCoW5iZACnA9E0sti5cSS2jqKp2OzUXRwKCLU555VXB9w3bQT1PhBzz93fGAGXBkwji68/ugU/n34QT2wfbfHz0pFCuoEuOCjkTe39I6i7ptzIikBQD6TCFBhFjWlkcXjvVhzeuxVD2VaN2LOtgFMXr7rqy76j5yKf0ekkFULudXFYRQlBdor9Ov5NjBWwZ1uhSfQUgGNnSyjOlRK/O29d9M6HmnOkGx0ppFt4ucR0A78VgJVGtuX3jIzg75n1pEFuxRAOPrwplmAqbxqN+8VtGLsCGs3odOhcdVGRCiEPayHU4VWSb3Hq4tWWfLQ9al85lNx/siWlYrM/EQK0lu8X8ib2bCtg5sQlzw378kK16ffypgEIcH3Z/WIJ4c6NIz0PplatHGps5OocZKXlARtedLPlRSpy5FFM+wH049/sDwTdU7VUrmDy1XlUY5gkEhTdg4rTgEgvcY59C9LX37Lc2qcKOUWzUq3h1MWrOLR7c0/ng1qa4CfCN24vwsiIp0Z0aw8hFUIORDPtJ8gDQbfhKYKORDwjQDefAW5vKNYOutv5cBoQ6QVBXCxu165XQAWg0d/bfo1bm/tR32sK9QeLn+mgWlMYzhlQqnVqkkW3DAapEfKo8Hsg6KJ2r4vRzx1izTCcfe8aXn7r/a6U3ztz3kEioUFwGJDeU5wrBSpkEyB0aw0ATW+Tuvv5s09/L9L7LGj0b41xczt3q/qzGyQ34RsTuhyfF379ia0c2wsTm/HOoS/j8N6t0S0Y9TU6L+YgkRDthyRqinMlTL467yvihbzpuZ/jZSwIkmt+7N61wRcdEMvy7EU+Z3i67JI86q3vcMvx6RjOGcitGPJ87XL+h41yw8PIuk/18Yu2aT8k3WDmxCXfFGSQa8+6/3TN5krlSiPqdpuH+cLEZrx79WOcfudO19Edn70bP/9lpaMGX5blWXcMpfRvE9103DAi98ErsjWyggMPbcLkrg0wMvpntTPyjSKlYXlrsZyPCzo0A6D9kHQPv2s7zLU3MVbwFD8rdeJm7SvOlfA3lz9q+vzpd67h2o1b8LhVfbEsz7pDfFSpRuayC8NAR+T2ctr88iaFVRIfJLL9xMohPHX0HNbkTey9Zy2Onf2gpem823/AfIdVotbFtGP6ZMtx7EMz3HL9FHDSTbxy2+0MId65cQRHzlz2zU07N+91AVjQoRBu2O9lr/bYUbnswjCwQu7cDLQLot2ep3Wx2H6nVK7g2NkSDu3+AgDv/4DFuRI+vrnY0dqtMXR+QzP81kJI1Ezu2uBq0zWyEjoiLc6VcOxsKbC90H4/RLmR77Yp62dljsJlF4aBFXK/zUCvyNZt99r6vF8xTpAcovUduiHK1hi6doZmENJN3NovD+cMHHhoU+hrMWwDLssmOLlrQ2TDLnRvEUkLlAZWyIM8sXWRbZgWuu18xrp41k8d9zyGV1QQVxc2QqIKINqJqq236T3bCjh2tqR9EDiDMSMrgGquFRHAcwh7kgKlgd3sDGK9s0e29p4uug2YMMfUYX89y+fcS37t63Kb6gMg0PQkQpKM372iG4JurwAddrmHTCOLx7ePNt03M1/Zgr33rNX2WUo6AxuRu0Wzdrx2mYOU+of5Xis6sNuodLl0Z67RLSrYMX0y0PQkQpKM1z1qbdx72RNnTlzCgYc2AQiWApk5cUmbMk36fTOwQu42nsrNtRLkd8OkLvx+tzhX8iwHXr
ViqOMuj4SkAes63//KfMtekSWwXlXV1pvood2bA7ll0nzfDKyQA53luLrxu0HK6j8KML/TbxOUkLQwMVbAU5qo+8py6tCLMBF1mu+bgc2RJ5GoyurjKEggpFt4tZ8OUi0ZNKJO831DIU8QUZXV6zZBk57nI8QNv2EVfv3Jg0bUfveNlfZcP3W86xN/wjLQqZWk4VcVF8ZCmCRrFCGdEGRPStefPGxEHTTtmbSe/qK60FLVj/HxcTU7O9vz7006xbkSJr8zj2rNURWXEcw8siURFwwhSaZb9RM6A0I7bQc6QUTOKqXGnT9nRJ4grJmAzhag1SWVCgsUIXHTrTfRpDtaKOQ9xi9i0LlSknLBEDKIJN3Rws3OHmLl2bwqLoMMiCaE9JakO1oo5D3EzV7onHaS9AuGkEEk6U4wplZ6SJA8W9K6qhFC6iTZCUYh7yFB82xJvmAI6Qbs1tkZkaRWROTbIvILEflJFMfrV5g2IaSVIHtHxJuocuR/BuCBiI7VtyQ9z0ZIHATZOyLeRJJaUUr9QETWRXGsfodpE0KaSbpHOw30zLUiIk+KyKyIzF69erVXX0sISTi03HZOz4RcKfWiUmpcKTU+MqIfn0QIGSy4d9Q5dK0QQmKFltvOoZATsgwtcPHBvaPOiMp++DKA/wVgg4h8ICL/KorjEtIraIEjaSYq18pjURyHkLjwssAxUiRJh71WCAEtcCTdUMgJAS1wJN1QyAkBLXAk3dC1QgYKnTOFFjiSZijkZGB4tngeR85cbgzndQ7QpQWOpBUKORkIinOlJhG3CONMoc+cJBUKOekrdGI7c+JSi4hbBHGmWD5zy6LojOYJiRNudpK+wauox0usMyK+hT9stUqSDIWc9AXFuRL2vzKvFVsvG2FNKd8qTrfJTgB95iQZUMhJ1yjOlbBj+iTWTx3HjumTXSt3tyLxmnJPnlwpV1zthXa8ouviXAmi+T3rAdGrcyXEDebISVfoZU7ZLe1hZ03ebLIXekXXbjl2XX5dUPefu53rvqPn8PybF3DgoU3MoZOuw4icdIVe5pS90hv2op6JsQJOT92HgibNks8ZLTn2ye/Ma4VfLR9T9yC5vlBl4y3SEyjkpCv0sneJLv+dFXGdiaqr4lQKLYJcrem8LsBwzgDgfU7cECW9gEJOukIve5fohPnrj25pEnErj/3U0XO4y8ggbxpNQ7A/qlRDfe/HNxdRnCv5nhM3REm3EaXZIOom4+PjanZ2tuffS3qHM28M1MXVLUKO6vu8inV069mzrYBTF6/iSrmCjIh2w9QLAbQedaAeuedWDLGQiHSMiJxVSo23/JxCTrpFkiohd0yfdM11+4lwpxhZARRQXbrzLd18oJH+RifkdK2QrpGk3iVeG5ZOsiJYUgr5nIGPby42iXBQBPU00o1biyg7UjZRDKxI0kOSxA+FnKSeIKKWDZE2WVIK704/2Dj2wTcutIixH9bvr5867vr3urx5kHNhuwDihEJOUk1QUQuT+87njEYqJswDwCIrd8qH1uRN17eBjAjWTx1vEuug58KxdMQJhZykGj+/uhXdBhVkIyv4+OYiri/UI/B2Nj+zmfoDZmKsgMldG1o2We3HtYu17lz2vzIP4I6YcywdcUIhJ6lGJ16lcgWTr8438ttBBDkrgqGMoFJd6mhNt2uqIc4AcJeRaQi02+aq9eDRnYvVCwaoi7kuyudYusGFPnKSarzEy22T0p72cFJTqmMRt6hUazj4xgU8/dr5RnQP6B0yVk7c63jWWwbH0hEnFHKSavyaYTlZUkpboh815UrVsweMHStX7nUuVsQ+MVbAod2bUcibTQVNzI8PLkytkNRiOTyCiiVwRzD3HT3XvYW1wc6NIw0h3v/KvGsqyB6xJ8naSeKHETlJJfYhEkGx0g9JFMBTF68CqAv01x/dwtQJCQUjcpJKgkbisry7mPSiGftGp7VGu3/9LoMxF9FDISepJLDVTt0pzrGTN43QRT7dxG2j89binY
1XqyUuwKIf0gof8ySVBLXa6T538OFNMDJ6B0svsadNrA6N+46e44xQEhhG5CRW2u0Zoiu0sWNkBddv3MK65TJ5y8NtFQcN5wwohZ5H5nYv+XDOaEwRcuvQ6IRFP8QNCjmJjU56hvg5PABgsaaaBkNY/8v6/PWFKkwji1UrsrhxO7jzpVPsq71p860Hyfuz6Ie4wdQKiY1Ox8FNjBWw5FGxGaS4vlKt9VTE3b7fOt8g0fbOjSPdXhJJIYzISWxE0TNEV66eJqzzDXIu353/sDEII+lOHNI7GJGT2Oh0HFxxroSF24tRLikWLGthkCrVcqXaNByaw50JQCEnMRKmZ4jl5lg/dRw7pk/i2eL5lj4maaVSXcKzxfOYGCtgz7Zw0TWdLARgaoXEiJUSaGeQwpEzl7s6oq3XvHTmMl46cxntGCLpZCEUchIrQXqGuG2K6kTcGrGW1rx5Ow8nOlkIhZwkFstjHkaUMyKpFfF2YQ8WEkmOXEQeEJFLIvJTEZmK4phksGmnKRbQ3kSfNJM3DbpWSOdCLiJZAN8E8CUAnwfwmIh8vtPjksEmbHvaQURQd7HsmD5J58qAE0VEfg+AnyqlfqaUug3gLwD8XgTHJQMMN/C8sZf504ZIohDyAoD3bX/+YPlnTYjIkyIyKyKzV69ejeBrST/DDTw9edPQzv0kg0kUQu7mmGpJVCqlXlRKjSulxkdGWGZMvHHzmBsZwXDOgACN/z+I6Jp88S1mcInCtfIBgLW2P38awJUIjksGmCAec7dugW5T6gcFvsUMLlEI+Y8AfE5E1gMoAfgDAF+N4LhkQHBrZQv4Fwp5if2O6ZMDZUPkKLjBRlQEdi0R+TKAwwCyAL6tlPqa1+fHx8fV7Oxsx99L0o9bVG1kBEsAakt3rk0jK5j5yhat1c75MNi5caTvqj+dZEWwpBSbZw0QInJWKTXe8vMohDwsFHJiESZyHs4ZmHvu/pafuz0MTCOLL46uxg/fuda3Yi5wH2NH+hedkLOyk8RKmA06q0GWM/peuL3o2tf857+s4Bt7t4auDk0La/Jm2xOWSH/B7ockVsJu0NkrPq1WrroOiKVyBftfmU/FMIawDhzTyGLnxpGWfwv6yQcTCjmJlSA9uC3yphG64rOmFF46cznxF7o1SzQIWRHs2VbAy2+9zwHNBACFnMTMxFgBh3ZvRiFvQgAU8iae2D7aMuHeyAgOPryp7RTJkv9HYqemVKCHWk0pHDtb0vaVoZ988GCOnMSOWyvb8c/c7Zr79Rq2nHayIji0e7PvOWZFPN9K6CcfPCjkJJHo+pSnUcTzpoFf3axiyWfpNaUwc+KS5zmaRtZTxOknH0yYWiGx4xzj5rVZV0hhtLlq5RC+eu8o8qbh+TkBPFNHIsCh3Zu1x7EierpWBg8KOYkVNxeKl/MizOZoUiiVKzh2toSDD2/ydKf4vWsMZQSz713DDZeB00ZG8PVH9QVTpL+hkJNYcXOheDkvrM3R4Zx3dJs0KtUann/zQkf562pN4eW33ke11ir5n7hriCI+wFDISazoHBZezouJsQJyK9K3vXN9odpxYZIuf17WeOnJYEAhJ7Gii1D9ItdBtdjpvOZ0qgw2FHISK2457yDOi6QJl+WB7yamkcVj965t69+L9DcUchIrbgVBQZwXk7s2wMi2RqcZ9H7oRCFv4t3pB3F66r6uiLn93+WFic1t/XuR/obdD0lqKc6V8PybFxq9VvKmgYMPb2qIWi96khsZwcwjd9wibp0YO6GQN3F66r7An2cTrf6G3Q9J3+FWNGR50q+UK8j3wNnidItY//vgGxe0I9nc0PVgD5MycT5ELCvn7HvXcOriVYp7H8PUCkkcYQqEnL9n96RfX6gim2lNsthnfxbyJlataN+X7uYWmRgr4NyB+3F479aW4h1rOcM5A3nzzhr23rO29WYM+bKss3IeOXOZHRL7HEbkJFHookoArlGkPZUAAZyZwtqSwnDOQG7FkOfsz31Hz7W1Xq
9NV12bATd2TJ9E1VHDX12ql+wHPYbOyeN8Hlg+fUbl/QOFnCQKrwIhtzRKUz5aE8GWF6quk4UsJsYKmH3vGl46cznUWsO4Rfxy1+346Z2syZuB9wQG1b7ZrzC1QhJFUEErzpWw/5X5QJuKQayKL0xsxuG9WxtuENPwvjXsfU38UkFB2hC066e342bl1Ll3kmbfJJ1BISeJIoigWcIYtBNi0Kh5YqyA01P34d3pB/H2n3wJT2wfhVv9jWlkG31Ngoh0kDYE7frpnet3WhMf3z5K3/kAQPshSRS6Qcp2r3QUA5vDrsmeFtm5caThAnHLywPNtsH1U8ddsz7O4cn271ltGhCpp4U6dZrQktg/6OyHFHKSOPyERyeMTpwPgKjWFsQnbhdp3YNH5xF3+w4jK1i1YggfVToXdpJe6CMnqcHP7RFkU6/d3tx+D5GgM0PtqaDJXRtc3zJ06Q2376jWVMOX7ufkIYMHc+QkdQTpSW5N2wnjlw6S7w7q9rCLdNg2BEG+g0OWiR1G5CR1WAI4c+KSZ2QexoO+Jm/ixq1F103J59+80Pj9IG8Dwzmj5fvCeMqD2ghpISQWjMhJKrEcJof3bvWMznWRq1v0rSupv75QbUTlk7s2eDbkMpZLN8NWpdoJOgWJFkJiQSEnqcaettDhFrkGzXXbP2993+PbR13FPGdkAKkLfyfl8M5UzHDOaDwgLGghJHaYWiGpx0pb6NwhbpFr2LTElXKlKRWTzxlQCk0uErdUT7vl8M5UDC2ExAsKOekb3NwhAmDnxpGWz+ry0KLxha82jaZjX1+owjSy+MberQ1BfUrTryWKXHaYHDsZPJhaIX3DxFgBe7YVmtIeCsCxs6WW9IaukvLxe90rIUXgW50ZRZk9Ie1AISd9xamLV7Xd/uzoLIG6CTy64cb2aDuKMntC2oGpFdJXhOkiqEtXuP1cZ3W0R9t2WyRz2aSXUMhJX6HLfVuC2+6mYdDqTOaySRwwtUL6Cq/0RpDKTR3tDon2o91pSITYYURO+gqv9MaO6ZOBh1bojt3NBlzsoULahUJO+g6d4EYxhSdKwkxDIsQLplbIwJA0e2DSHiwkvXQk5CLyiIhcEJElEWnpkUtIkkiaPTBpDxaSXjqNyH8CYDeAH0SwFkK6Src2LNslaQ8Wkl46ypErpd4GAHEbbEhIAkmSPZC+cxIV3OwkJEaS9GAh6cVXyEXkvwH4TZe/ekYp9V+DfpGIPAngSQAYHR0NvEBCCCHe+Aq5Uup3ovgipdSLAF4E6sOXozgmIYQQ2g8JIST1dGo//H0R+QDAbwM4LiInolkWIYSQoHTqWnkdwOsRrYUQQkgbiHIbh9LtLxW5CuC9Nn71kwD+LuLlJIF+PK9+PCegP8+L55QePqOUahl5FYuQt4uIzCql+q6CtB/Pqx/PCejP8+I5pR9udhJCSMqhkBNCSMpJm5C/GPcCukQ/nlc/nhPQn+fFc0o5qcqRE0IIaSVtETkhhBAHFHJCCEk5qRNyEfkTEfmxiJwTke+LyJq419QpIjIjIheXz+t1EcnHvaYo6KfBIyLygIhcEpGfishU3OuJAhH5toj8QkR+EvdaokJE1orIKRF5e/na+3dxr6kXpE7IAcwopb6glNoK4LsAnot5PVHwVwD+kVLqCwD+D4CnY15PVPTF4BERyQL4JoAvAfg8gMdE5PPxrioS/gzAA3EvImIWAexXSv0WgO0A/qhP/lt5kjohV0r9yvbHVQBSv1urlPq+Umpx+Y9nAHw6zvVEhVLqbaXUpbjXEQH3APipUupnSqnbAP4CwO/FvKaOUUr9AMC1uNcRJUqpD5VSf7P8v38N4G0Afd/wPZWDJUTkawD+OYCPAOyMeTlR8y8BHI17EaSJAoD3bX/+AMC9Ma2FBERE1gEYA/BWzEvpOokUcr9hFkqpZwA8IyJPA/g3AA70dIFtEGRAh4g8g/qr4ZFerq0Toho8knDcZhmm/k2wnxGRTwA4BmCf4y2+L0mkkI
cYZvFfABxHCoTc75xE5F8A+F0A/1SlyNwf1eCRhPMBgLW2P38awJWY1kJ8EBEDdRE/opR6Le719ILU5chF5HO2Pz4M4GJca4kKEXkAwL8H8LBSaiHu9ZAWfgTgcyKyXkRWAPgDAG/EvCbigtQnwX8LwNtKqT+Nez29InWVnSJyDMAGAEuot8L9Q6VUKd5VdYaI/BTASgC/XP7RGaXUH8a4pEgQkd8H8J8BjAAoAzinlNoV66LaRES+DOAwgCyAbyulvhbvijpHRF4G8E9Qb/n6/wAcUEp9K9ZFdYiI/GMA/wPAedQ1AgD+WCn1vfhW1X1SJ+SEEEKaSV1qhRBCSDMUckIISTkUckIISTkUckIISTkUckIISTkUckIISTkUckIISTn/H1/3R7ZTt3xEAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# scatter plot of the 2-D GMM samples\n",
"plt.scatter(data[:,0], data[:,1])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "9067db26",
"metadata": {},
"outputs": [],
"source": [
"def make_nf(seed,\n",
"            prior_x_sampler,\n",
"            prior_x_logp_fn,\n",
"            posterior_x_logp_fn,\n",
"            nf_module,\n",
"            v_sampler,\n",
"            v_logp_fn,\n",
"            nf_module_kwarg_dict,  # constructor kwargs for `nf_module`\n",
"            ):\n",
"    \"\"\"Build a phase-space normalizing flow and return `(params, wrapper)`.\n",
"\n",
"    Arguments:\n",
"      seed: PRNG key; split to draw the batch used for shape inference and to\n",
"        initialize the flow parameters.\n",
"      prior_x_sampler: `key -> (N, d)` array of prior positions.\n",
"      prior_x_logp_fn, posterior_x_logp_fn: batched (unnormalized) log-densities\n",
"        of the prior and target position distributions.\n",
"      nf_module: flax module class whose `apply` returns `(outputs, logdetJs)`.\n",
"      v_sampler: `key -> (N, d_v)` array of auxiliary velocities.\n",
"      v_logp_fn: batched (unnormalized) log-density of the velocity distribution.\n",
"      nf_module_kwarg_dict: keyword arguments used to construct `nf_module`.\n",
"\n",
"    Returns:\n",
"      params: float64 flow parameters.\n",
"      wrapper: `(params, key) -> (phase, posterior_samples, works, logdetJs)`;\n",
"        draws a fresh batch, pushes it through the flow and computes the\n",
"        per-sample work\n",
"          -log p_post(x') + log p_prior(x) - log p_v(v') + log p_v(v) - log|det J|.\n",
"    \"\"\"\n",
"    def _draw_phase(key):\n",
"        \"\"\"Draw prior x and v from `key`; return (x, v, concat([x, v]), leftover key).\"\"\"\n",
"        x_key, carrier_key = random.split(key)\n",
"        prior_x_samples = prior_x_sampler(x_key)\n",
"        v_key, carrier_key = random.split(carrier_key)\n",
"        prior_v_samples = v_sampler(v_key)\n",
"        phase = jnp.concatenate([prior_x_samples, prior_v_samples], axis=-1)\n",
"        return prior_x_samples, prior_v_samples, phase, carrier_key\n",
"\n",
"    _, _, phase, carrier_seed = _draw_phase(seed)\n",
"    nf = nf_module(**nf_module_kwarg_dict)\n",
"\n",
"    # initialize from a key split off `seed` so that `seed` controls the initial parameters\n",
"    init_seed, carrier_seed = random.split(carrier_seed)\n",
"    params = make_precision64(nf.init(init_seed, phase))\n",
"\n",
"    def wrapper(params, seed):\n",
"        prior_x_samples, prior_v_samples, phase, _ = _draw_phase(seed)\n",
"        _, d = prior_x_samples.shape\n",
"\n",
"        posterior_samples, logdetJs = nf.apply(params, phase)\n",
"        posterior_xs, posterior_vs = posterior_samples[:, :d], posterior_samples[:, d:]\n",
"\n",
"        x_works = -posterior_x_logp_fn(posterior_xs) + prior_x_logp_fn(prior_x_samples)\n",
"        v_works = -v_logp_fn(posterior_vs) + v_logp_fn(prior_v_samples)\n",
"\n",
"        works = x_works + v_works - logdetJs\n",
"        return phase, posterior_samples, works, logdetJs\n",
"\n",
"    return params, wrapper\n",
"\n",
"\n",
"def pull_out_pv(phase):\n",
"    \"\"\"Split a phase-space array into its position and velocity halves along the last axis.\"\"\"\n",
"    dim = phase.shape[-1] // 2\n",
"    return phase[..., :dim], phase[..., dim:]\n",
"\n",
"def compute_loss(params, seed, wrapper):\n",
"    \"\"\"Training loss: mean of the per-sample works returned by `wrapper(params, seed)`.\"\"\"\n",
"    _, _, works, _ = wrapper(params, seed)\n",
"    return jnp.mean(works)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "9a5410eb",
"metadata": {},
"outputs": [],
"source": [
"# one-layer AlternativeRNVP flow: x ~ N(0, 0.1 I) -> target N(2, 0.1 I), auxiliary v ~ N(0, I); batch of N samples\n",
"N = 100\n",
"dimension = 2\n",
"params, wrapper_fn = make_nf(seed = random.PRNGKey(0),\n",
" prior_x_sampler = partial(sample_normal, \n",
" N = N, \n",
" mu = jnp.zeros(dimension), cov = 0.1 * jnp.ones(dimension)),\n",
" prior_x_logp_fn = vmap(partial(unnormalized_Normal_logp, mu = jnp.zeros(dimension), cov = 0.1 * jnp.ones(dimension))),\n",
" posterior_x_logp_fn = vmap(partial(unnormalized_Normal_logp, mu = jnp.zeros(dimension) + 2., cov = 0.1 * jnp.ones(dimension))),\n",
" nf_module = AlternativeRNVP,\n",
" v_sampler = partial(sample_normal, N = N, mu = jnp.zeros(dimension), cov = jnp.ones(dimension)),\n",
" v_logp_fn = vmap(partial(unnormalized_Normal_logp, mu = jnp.zeros(dimension), cov = jnp.ones(dimension))),\n",
" nf_module_kwarg_dict = {'log_scale_features': [dimension,dimension,dimension], 'translate_features': [dimension,dimension, dimension], 'num_layers': 1})\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "768c56a6",
"metadata": {},
"outputs": [],
"source": [
"# loss as a function of (params, seed) only, so value_and_grad differentiates w.r.t. params\n",
"loss_fn = partial(compute_loss, wrapper = wrapper_fn)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "e751e778",
"metadata": {},
"outputs": [],
"source": [
"opt_init, opt_update, get_params = optimizers.adam(step_size=5e-3)\n",
"\n",
"\n",
"@jit\n",
"def step(i, opt_state, seed):\n",
"    \"\"\"Take one Adam step on `loss_fn`; return (loss before the update, new opt_state).\"\"\"\n",
"    current_params = get_params(opt_state)\n",
"    loss_value, grads = value_and_grad(loss_fn)(current_params, seed)\n",
"    return loss_value, opt_update(i, grads, opt_state)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5be14183",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 17,
"id": "34978a38",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 10000/10000 [00:05<00:00, 1870.30it/s]\n"
]
}
],
"source": [
"import tqdm\n",
"\n",
"iters = int(1e4)\n",
"opt_state = opt_init(params)\n",
"seed = random.PRNGKey(78)\n",
"mean_works = []  # loss (mean work) recorded at every step\n",
"\n",
"for i in tqdm.trange(iters):\n",
"    # fresh key per step; `seed` carries the PRNG state forward\n",
"    seed, run_seed = random.split(seed)\n",
"    _val, opt_state = step(i, opt_state, run_seed)\n",
"    mean_works.append(_val)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "f368ab0c",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD9CAYAAABQvqc9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAjO0lEQVR4nO3deXxU1f3/8deHYFCpdUP9WhaDQrHUXYp2sy5V0Uq1Vr9Kl2+tFL62D/219fetQq21Wv1C3YtSKQoiakGkiiAgOwKyhn0NBGSJbGGHQMh2vn/MJJlMZpJJZu7Mzcz7+XjwYObMved+TiCfOffcc8815xwiIpIZmqU6ABERSR4lfRGRDKKkLyKSQZT0RUQyiJK+iEgGUdIXEckgSvoiIhlESV9EJIMkLemb2R1m9rqZfWRmNyXruCIiUi2upG9mQ81st5mtCivvZmZ5ZpZvZn0AnHNjnHO9gPuAe+I5roiINI7FswyDmV0DHAGGO+cuCpZlAeuBG4ECYBHQwzm3Jvj5C8C7zrkl9dXfqlUrl5OT0+j4REQy0eLFi/c4586K9FnzeCp2zs0ys5yw4q5AvnNuE4CZjQRuN7O1QH9gYiwJHyAnJ4fc3Nx4QhQRyThmtiXaZ16M6bcGtoW8LwiWPQR8H7jLzB6ItrOZ9TazXDPLLSws9CA8EZHMFVdPPwqLUOaccwOAAfXt7JwbbGY7gO7Z2dlXJjw6EZEM5kVPvwBoG/K+DbC9IRU458Y553qfeuqpCQ1MRCTTeZH0FwEdzay9mWUD9wJjG1KBmXU3s8EHDx70IDwRkcwV75TNEcA8oJOZFZhZT+dcGfAgMAlYC4xyzq1uSL3q6YuIeCPe2Ts9opRPACY0tl4z6w5079ChQ2OrEBGRCHy5DIN6+iIi3vBl0o93TD9/9xH6frAiwVGJiDR9vkz68fb0v//ip4xYuI2350e9P0FEJCP5MuknavbO42NW1b+RiEgG8WXS15i+iIg3fJn04zXoZ7qRV0QkEl8m/XiHd27++jlVr3ceLE5UWCIiTZ4vk368wztm1cv//O69pYkKS0SkyfNl0k+ENqefBMD8TftSHImIiH+kbdL/ymknpToEERHf8WXST8SUzX53XpzAiERE0oMvk34ipmxecNaXEhiRiEh68GXSFxERb6R10m8dHNc/crwsxZGIiPhDWif9m7/+HwAcOlaa4khERPzBl0k/UWvvXNbuNACOlqinLyICPk36iVp7J3/XYQDmbdybiLBERJo8Xyb9RLnivNMBaHFCVoojERHxh7RO+hsLiwD4k5ZYFhEB0jzpX9fpLAC6X/KVFEciIuIPSUv6Zna+mQ0xs9HJOma7M06u8beISKaLK+mb2VAz221mq8LKu5lZnpnlm1kfAOfcJudcz3iO11DNs5rRonkzijR7R0QEiL+nPwzoFlpgZlnAQOAWoDPQw8w6x3mcRvtSi+YU6eYsEREgzqTvnJsFhK9d3BXID/bsS4CRwO3xHCceLZX0RUSqeDGm3xrYFvK+AGhtZmea2SDgcjPrG21nM+ttZrlmlltYWBh3MCdnZ3HkeHnc9YiIpIPmHtRpEcqcc24v8EB9OzvnBgODAbp06eLiDWbdzsOs23k43mpERNKCFz39AqBtyPs2wPaGVJCoZRhERKQmL5L+IqCjmbU3s2zgXmCsB8eJSadzTknVoUVEfCfeKZsjgHlAJzMrMLOezrky4EFgErAWGOWcW92QehO19g7A11t/uWqJZRGRTBfXmL5zrkeU8gnAhMbWa2bdge4dOnRobBVVTmjWjPKKuC8NiIikBV8uw5DInv62/UfZeagY55T4RUR8mfQTeSF3bnBZ5cLDx+OuS0SkqfNl0k9kT7/S0RLN1RcR8WXST2RP/6HrA9cFtP6OiIhPk34ie/pXtT8TgIN6Tq6IiD+TfiIdOR5I9n/6UA9SERHxZdJP5PDO3qISADbtKYq7LhGRps6XST+Rwzuaoy8iUs
2XST+Rbr343KrXFfoCEJEMl/ZJv9WXWlS9fmjE0hRGIiKSer5M+l6tsjl+5Y6E1ici0tT4Mul7cXNWpbLyioTXKSLSVPgy6Sda+1Ytq153eGxiCiMREUmtjEj6r/S4vMb7A0dLUhSJiEhqZUTSv6j1qfzw0q9Uvb/sqSkpjEZEJHV8mfS9uJD793svq/E+p894xiz9ImH1i4g0Bb5M+l5cyDWr/bz23723LGH1i4g0Bb5M+sn09vwtqQ5BRCRpMirpz3n0ulplj49ZxctT13PkuJZeFpH0l1FJv83pJ7P6yZtrlb88dQPdXp6VgohERJIraUnfzFqa2Vtm9rqZ/TRZxw3XskVz/njrhbXKC/Yf4+BRrbkvIuktrqRvZkPNbLeZrQor72ZmeWaWb2Z9gsV3AqOdc72AH8Zz3Hj1vuYC3rq/a63yS5+azK/eWpSCiEREkiPenv4woFtogZllAQOBW4DOQA8z6wy0AbYFN0v5A2u/99Wz+P83frVW+dS1u9l9uDgFEYmIeC+upO+cmwXsCyvuCuQ75zY550qAkcDtQAGBxB/3cRPloRs6Rizv+sw0Xpm2Aee0FLOIpBcvkm9rqnv0EEj2rYEPgB+b2WvAOA+O2ygv3H1p5PIp6/l/I5clNxgREY95kfRr3wUFzjlX5Jz7pXPu1865d6PubNbbzHLNLLewsNCD8Gr68ZVton42bvl2ijSVU0TSiBdJvwBoG/K+DbA91p2dc4OBJ4El2dnZCQ4tsuys6D+Gnw1ZkJQYRESSwYukvwjoaGbtzSwbuBcY68FxEuaErEgnJwFLtx6gVGvwi0iaiHfK5ghgHtDJzArMrKdzrgx4EJgErAVGOedWN6ReLx+iEsmzd11K+1YtmfrwNRE/f2nK+qTEISLiNfPjDBUz6w5079ChQ68NGzYk9dhD53zOUx+vqVW+uf8PkhqHiEhjmdli51yXSJ/5YupkuGT39EPd/532Ecsnr97JkDmfJzkaEZHEap7qACIJ6emnOpQqvd9eDEDPKF8KIiJNgXr6Ebx8z2UpOa6IiNd8mfRT7Y7LW3P/tyP36LUEs4g0Zb5M+l48LrGhok3jHKpxfRFpwnyZ9FM9vFOXF6es5/M9RakOQ0SkUXyZ9P3gN9d14O4oSzRc9/xMJq7ckeSIRETi58uk74fhnVNPOoHnoizGBvDrd5ckMRoRkcTwZdL30/BO/jO3pDoEEZGE8WXS95PmWc2Y2+f6iJ/9buRSKioCdzTP27iXnQf18BUR8TfdnBWDk7OzIpaPWbadMcu20+u77Xl9dmBWT97T3WjRPPL2IiKp5suevp+GdwCym9f9Y6pM+ABvz9vidTgiIo3my6TvNydnx35C9PT4tbz52ecJXY45b+dhiktT/lhhEUkDSvoxOr9Vy5i3fXLcGm5/9TOen5THqEXb6t+hDvuLSrj55Vk8MnpFXPWIiICSfsym/8+1tDn9pJi3X7PjEK/OyOeRf1cn62Ml5fzh/eXsKyqJuZ6iksCyD4u37I89WBGRKJT0G6CZRX/CVl32FZWwbNsBRi8p4P3FBfSbsLbWNi9OzmPxln1R6/Djcw9EpOnxZdL3w81ZkTx841cbtd8Vf53CHQM/40Cwh//+4gI+WbWDSat3Vm0zYHo+P35tXq19LfhFs91H00GPlpSRv/tIqsMQkUbwZdL32+ydSndc3jqu/edt2lv1+oF3lvDfby/mb5+sizespOs1PJfvv/hp0o+758hxXpycV3VvRCTT1+1i+bYDyQtKpInxZdL3s5n/cy1v/vIbjdp37sa9tcpem7mxwfVUVLgaia/fxLV87fFPGlRHcWk5B4+WNvjYAJ/l125HMvT590oGTM9n/qbox79/WC63D/wsiVGll+0HjrH9wLFUhyEeUtJvoJxWLbmu09kJrXP4vM013r80ZT05fcaz58hxng55Xm/luP4VT0/hu8/OqCr/56ebONbAKZ13D5rHpU9NrlFWXuEYOCM/5mcGJPs6w/
GyQBtL6+jpp8K2fUfZWBj/cNefxqxkcsiQXyp8q/90vtV/ekpjEG8p6fvAnz9aXfW6tLyCv08LPAy+y9NTmbiqOgk8MTaw3YGjpXwR7I2NXlzQqGOu/KLm9ZK3523m9+8t47lJeREvNPtB5YX0ighfNp/vKWJzA5a8rqhw/PPTjRwrif/+h+8+O4MbXoh/uOud+VurHssp4pWkJX0zO9/MhpjZ6GQdsynq+NjEqJ8Nn7eForBe+P+8v7zq9Wf5e2I6xhuzN1W9nrByBxf8cQKPf7Sascu3A3A0xkQYb0d/y94i9hw5HvP2lZOnIp1hXPf8TK59fmbMdf3otbn0m7iOvh/44/6HeK5DHCsppyyBNwOmyrGScro+M5U5G2L7f+yFUbnbWPWFvyaQJFpMSd/MhprZbjNbFVbezczyzCzfzPrUVYdzbpNzrmc8wfrJ1Iev4a93XJT04/70jQVVr8vDhjl++saCmH5hnh5f3ZP/zbtLatXz4dIv+HBp484gYnHwWOBawveem0mXp6fGvF9lTz8Ro0qVSXbdzsPxVxb04pT1bNh1mAmNeNZC6HWItTsONWjfr/35E3q/vZjZGwob/CVWWl7B42NWUXg49i/fhsrbeZgFm/YyN6xT8of3l/Py1PVV7/N3H2H34eP0/yR1Z5qPjF7Bba/M8fw4KwsOcs2zMzhU3LjravGItac/DOgWWmBmWcBA4BagM9DDzDqb2cVm9nHYn8QOgvtAh7NP4edXn5f04y4L6RHe9+bCWp//YfRyKioco3K3BYaKpm4gp894PljSsCT++/eW17tNQ3PvpU9OJqfPeC59cnKd9yQA7DpUzBuzN9Xo1TcL9vQbM6S/r6iEWesLG7zfX8au5oXJeVE/D/33GDgjnxtfmsVvYnzWQkWFI6fP+BpnXgDrdzX8i2j6ut38fMhCRizcxrf7T+fPH62qdx/nHFPW7OLt+Vv4xjNT+fmQBVG33V9UwprtDfsyqnTzy7O4Z/B8fvLGghrDae8vLuDlqRsixNWowyRU5Yq5U9bs4sDR2jdTFuw/Wuf+j45eQU6f8RwqLq2V2Mcu3073V+ewdd9RFn1e9++BF2JK+s65WUB4dF2B/GAPvgQYCdzunFvpnLst7M/uBMftG53OOSVlx54doVe/42Ax41Zs55HRK+j42EReCvakHh61vMEPdc/dvI8NwQRUWl7BG7M3UVJWPYxwwR8n8OHSApZurXm38NQ1u/hkVe0LkpU9fIBl26pPoXcdKuaCP05g+rpdQOBaxlX/O42nx6/lx6/NDamhekz/wNGSqtjqsutQMe37jueKv07hv4YurLWGUaTrA6GGzd3MK9PzGZUbWE7jwX8tIafPeI6XlVNR4bgjpIcefsZUn8rrNaFnXlCd9IpLy8npM77q2M9NWseCOmYuVfriwDGGBxf+c85F/TndOmBOjS+oSP+fKt352lxuHTC73mPXp6yigpw+4+k3sbrNlUNT0e597PvBSsYFhx7rM3XNLqav28XeI8fjuov96n7T2H2omF7Dc3ngnZrXWT5cWsB3/jaDf8zMj7hv7uZ9vBf8N7vkL5O55C+Ta5zRvDKt9hddMsUzpt8aCF1YpiBYFpGZnWlmg4DLzaxvHdv1NrNcM8stLGx4zyzZRvS+OtUh1PLbkcsill/0xCQeGV1/D77SXYPmceNLs9hx8Bi/eiuXp8ev5dXpNf/D/v695fzoH3NrlP0q+ItSXuE4XlZO3w9WMOyz6A+UX7btAOUVjr8He32h4/xLth6oet2sakwffjZkATe+NIui42Ws2xm9Bzp7w54aPcfwHL9+V2yzbirXPvp4RWDo5ujxct5duDWmfaM5cjzyqf3+oyUUl5ZXDRO9OHk9xaXlDJyxkXsGz2/QMf4xcyM3vjSL12dtqvVZQ4aRwp8L/buRS6vODHoPz+WZ8Wsi7VZL5Y//n59WxzNi0TbW7zrM6u2BjkDlv9HR4BIkIxZu5aERS6u2z98dfQHCXw
3P5f5hudw6YHZYhyFgyppdVdeuKr0yLXA2HH5Rv3JG3JKtB3hl2oaqs87KyRPPfhL5DPCuQbVvshw4o/oLor6OhtfiWU8/0vdy1NY45/YCD9RXqXNusJntALpnZ2dfGUd8SXFGy2z633kxfT5YmepQYjIqt+Fj9d/sVz2Fb8D0yL0bgCFzPueyttU31P1l7Gq+OHCM6evqPtGr/I9U36/C5DWBMwHnHKu+CCSsXsNzI97/EE1pRQUnEfl5B845Sssd2c2bUVZeUWe9x0rL65zPfuR4GV9qUfvXyznHa59u5EeXt8Yi/goFFux7clx1Et15qJgLw+7DOFRcyiktmnPoWPSzt5w+46tePzNhLb2uOT/qttG8u2ALR4/XTrBjlm2vdYy7u7SlpKyCi1pHv6kyUr4rOl7GTS/NqlH2yaqdPPDOYsY9+J0a5Uu27ufOkE7GmqdujrgK7q5DgY7Dmu2H6PyVL1eV9xqeC8C3LziTgv3HuLTtabwwJdAL/3DpFzXqqDxxKymr4IUp6/l2x1bMzd9T530qoWezoUrLHcWl5Zx4QhYbC4tCypN/AT6epF8AtA153waI7RysHs65ccC4Ll269EpEfV67/mtpd8miwUJ/+Su9PT/6swVCh4QqL6auKIg+a+K3I6t7epWnzgCLNkcfE/35kAW1em/vzN/Cb66tfjhP82bGwWOlZDUz3p2/hX4T1zG/7w1c3W9a1HohMGvq0ranRf38oicm8eQPv84vvpVTo3xjYRHPfpLHpNW7+OlV7eo8RjTb9h3lu8/OoNM5p9DzO+1j3m/r3qMUl5Vz3pknx/Sgn9LyCh77sPa1gRuj3I1dmbg39/9B9EojJP3wfyMHfBq8/rK84EBV+ZHjZdzzz5q96NXbD7Gp8Aj3fCPyz/LNzz6P+Kzr7q/MYfvBYj4LeSreHz+s2XG7Lmw2WHmF4/nJ62uU/fLNhczIK2Rz/x+wde9RrnluBtFc+PgnXH9hzVzxh/dXcMV5p9MiK4tTTz4h6r6JFM/wziKgo5m1N7Ns4F5gbCKC8uvaO9GcfcqJ/HcjelGZrHKYBAKzXir1nxh5WYqPllX3J2bmVQ/7lZZHPj947MOVzN6wh9ywcd3isARzzpdP5NInJ3PRE5PoFzx2tIQf+kW1fNuBKP30apX3VVR6a+5mZm8orNo/fPptrHq8Hhjiydt1uMYqrvVZuHkfN700i2/1q//mq+LS8oj3gDw3aR0b6ll3qa4pjy5C1v972Bj32h2HIo7vX/TEpFr/3ncPmsej/14ZdWmO96Pcx1K5llW0L7BI7o4wbDMj+H/xvUVb60z4lcLPeg8fL6PrM9O49KnJVUNoD49axg0vzIw5roaKdcrmCGAe0MnMCsysp3OuDHgQmASsBUY551bXVU86+95Xz0p1CGlh0Ke1l6VYWccZQDTvLog83j5gen6Ns5JIMzOiCb12URTjvQxzNwYuju4vKuGJsatrDNuEvm6Igv2NWyahcqbI3qIS/jSm7uHICx//hDkR7vsYOKP+ZUNue2UO1zwbOQF+N0p5vMxg2tpdET8b9OlGfvnmwohDKbHek1KfR/8d//Dudc/PZMLKHXyw5As2FhY1eOpurGKdvdPDOXeuc+4E51wb59yQYPkE59xXnXMXOOeeSVRQfl1wTVKj+6vezZuONXlHEq0XGeonrwcudi6sYxgqWUKHxd6ZX/9F6PErGn6/QaWt+47yfu429oc9O+JwcWxnN/8KfmkPrWMCQKiNhUfo+VZuxM/6T1zHjLxCxoSN2ftR6LRYr1ay1YPRRRqpITc0Ne5JDE3bHxLwtLdNhbEtrRHLvRHjG3HTXLKFzun3apaPL9feaYo9/RPqeXi6ZDZr5AN4JDaxTL0NvRbkV5X3V4B3/2eUqRKky3mnpzoE8an1uw5XTRUUiZVX3QRfJv2mNnsH1JOT6MLnoIukki+TflMc3hERSSSv+pG+TPoiIpku2h3b8fJl0m+KwzsiIokU6Ua2RP
Bl0m+qwzunJ+k2ahGRxvJl0m+qIl3MfeB7F6QgEhGRyJT0PdbnlgtTHYKISBVfJn2N6YuIeMOXSb+pjukPv78rPwlZLvfF/6y9pKuISCr5Muk3VRe1PpX//dHFVe/vvKJNCqMREalNSV9EJIMo6YuIZBBfJn1dyIXlf74p1SGISBryZdJvqhdyEylZz8sUkcziy6QvIpLpPHqGipK+iIgfaZVNEZEMkhY9fTO7w8xeN7OPzExXKqO4qfM5qQ5BRFLso2XbPak35qRvZkPNbLeZrQor72ZmeWaWb2Z96qrDOTfGOdcLuA+4p1ERp6GW2VmcfUqLqvc5rVqmMBoR8YN9Rcc9qbchPf1hQLfQAjPLAgYCtwCdgR5m1tnMLjazj8P+nB2y65+C+wkw+9HrGd6za9V7PXhRRCo8Gt5pHuuGzrlZZpYTVtwVyHfObQIws5HA7c65fsBt4XVYYO3h/sBE59ySRkedZs5omc3uw8WpDkNEfOTLJ3kzbTveMf3WwLaQ9wXBsmgeAr4P3GVmD0TawMx6m1mumeUWFhbGGZ6ISNPU/ZJzPak35p5+FJFGIqKelDjnBgAD6qrQOTfYzHYA3bOzs6+MM74m45QTdTOWiFSL9FCmRIi3p18AtA153waI+5JzJt6R2/q0k7jlov9IdRgikubiTfqLgI5m1t7MsoF7gbHxBpWpa+9c0ua0wAtdyRURjzRkyuYIYB7QycwKzKync64MeBCYBKwFRjnnVscbVCb29AG+06EVANd3OrueLUVEGqchs3d6RCmfAExIWEQEevpA9w4dOiSyWt+7uM2pbO7/g1SHISJpzJfLMGRqT19ExGu+TPqZOqYvIlLJq0t7vkz66umLiHgj3nn6nmjqY/rXX3g2Xzv3lFSHISJSiy+TvnNuHDCuS5cuvVIdS2MMve8bqQ5BRCQiXw7viIiIN3yZ9HUhV0TEG75M+rqQKyLiDV8mfRGRTKdn5IqISNx8mfQ1pi8i4g1fJn2N6YuIeMOXSV9ERLyhpJ8E/e68ONUhiIgASvpJ0aNrO65qf0aqwxAR8WfST8cLuVEfHCwikkS+TPq6kCsi4g1fJn0REfGGkr6ISAZR0hcRySBJS/pm9jUzG2Rmo83s18k6roiIVIsp6ZvZUDPbbWarwsq7mVmemeWbWZ+66nDOrXXOPQD8J9Cl8SGLiEhjxdrTHwZ0Cy0wsyxgIHAL0BnoYWadzexiM/s47M/ZwX1+CMwBpiWsBSIiErOYHpfonJtlZjlhxV2BfOfcJgAzGwnc7pzrB9wWpZ6xwFgzGw/8q9FRi4hIo8TzjNzWwLaQ9wXAVdE2NrNrgTuBFsCEOrbrDfQGaNeuXRzh+YzuzhIRH4gn6Uda4j9qanPOzQRm1lepc24wMBigS5cuSpUiIgkUz+ydAqBtyPs2wPb4wglIx2UYRET8IJ6kvwjoaGbtzSwbuBcYm5iwRETEC7FO2RwBzAM6mVmBmfV0zpUBDwKTgLXAKOfc6kQEpbV3RES8EevsnR5RyidQx0XZxjKz7kD3Dh06JLpqEZGM5stlGNTTFxHxhi+Tvi7kioh4w5dJXz19ERFv+DLpq6cvIuINXyb9dOzpO92SKyI+4MukLyIi3vBl0tfwjoiIN3yZ9NNxeEdExA98mfRFRMQbSvoiIhnEl0lfY/oiIt7wZdLXmL6IiDd8mfRFRMQbSvpJ4nRvloj4gJK+iEgG8WXS14VcERFv+DLp60KuiIg3fJn0RUTEG0r6IiIZRElfRCSDJDXpm1lLM1tsZrcl87giIhIQU9I3s6FmttvMVoWVdzOzPDPLN7M+MVT1KDCqMYGKiEj8mse43TDgVWB4ZYGZZQEDgRuBAmCRmY0FsoB+YfvfD1wCrAFOjC/kpkn3ZomIH8SU9J1zs8wsJ6y4K5DvnNsEYGYjgdudc/2AWsM3ZnYd0BLoDBwzswnOuYp4ghcRkYaJtacfSWtgW8j7AuCqaBs75x4DMLP7gD3REr6Z9QZ6A7
Rr1y6O8EREJFw8F3ItQlm9oxjOuWHOuY/r+Hywc66Lc67LWWedFUd4IiISLp6kXwC0DXnfBtgeXzgBWoZBRMQb8ST9RUBHM2tvZtnAvcDYxIQlIiJeiHXK5ghgHtDJzArMrKdzrgx4EJgErAVGOedWJyIorb0jIuKNWGfv9IhSPgGYkNCICAzvAN07dOiQ6KpFRDKaL5dhUE9fRMQbvkz66Xgh1+nRWSLiA75M+urpi4h4w5dJX0REvOHLpJ+OwzsiIn7gy6Sv4R0REW/4Mumrpy8i4g1fJn319EVEvOHLpC8iIt5Q0hcRySC+TPrpOKavW7NExA98mfQ1pi8i4g1fJn0REfGGkr6ISAZR0hcRySBK+iIiGcSXST8dZ++IiPiBL5O+Zu+IiHjDl0lfRES8oaSfJHpwloj4gZK+iEgGSVrSN7NrzWy2mQ0ys2uTdVwREakWU9I3s6FmttvMVoWVdzOzPDPLN7M+9VTjgCPAiUBB48IVEZF4NI9xu2HAq8DwygIzywIGAjcSSOKLzGwskAX0C9v/fmC2c+5TMzsHeBH4aXyhi4hIQ8WU9J1zs8wsJ6y4K5DvnNsEYGYjgdudc/2A2+qobj/QohGxiohInGLt6UfSGtgW8r4AuCraxmZ2J3AzcBqBs4Zo2/UGegO0a9cujvBERCRcPEnfIpRFnZjonPsA+KC+Sp1zg81sB9A9Ozv7yjjiExGRMPHM3ikA2oa8bwNsjy+cAN2RKyLijXiS/iKgo5m1N7Ns4F5gbCKC0to7IiLeMBfDraJmNgK4FmgF7AKecM4NMbNbgZcJzNgZ6px7JqHBmRUCWxq5eytgTwLDaQrU5sygNmeGeNp8nnPurEgfxJT0myIzy3XOdUl1HMmkNmcGtTkzeNVmLcMgIpJBlPRFRDJIOif9wakOIAXU5sygNmcGT9qctmP6IiJSWzr39EVEJExaJv0Grv7pW2bW1sxmmNlaM1ttZr8Nlp9hZlPMbEPw79ND9ukbbHeemd0cUn6lma0MfjbAzCLdUe0bZpZlZkvN7OPg+7Rus5mdZmajzWxd8N/7mxnQ5t8H/1+vMrMRZnZiurU50grFiWyjmbUws/eC5QsirJFWm3Murf4QuGdgI3A+kA0sBzqnOq5GtuVc4Irg61OA9UBn4FmgT7C8D/C34OvOwfa2ANoHfw5Zwc8WAt8ksHzGROCWVLevnrY/DPwL+Dj4Pq3bDLwF/Cr4OpvAGlVp22YCa3d9DpwUfD8KuC/d2gxcA1wBrAopS1gbgd8Ag4Kv7wXeqzemVP9QPPghfxOYFPK+L9A31XElqG0fEVjKOg84N1h2LpAXqa3ApODP41xgXUh5D+CfqW5PHe1sA0wDrqc66adtm4EvBxOghZWnc5srF2w8g8AaYB8DN6Vjm4GcsKSfsDZWbhN83ZzAzVxWVzzpOLwTafXP1imKJWGCp22XAwuAc5xzOwCCf58d3Cxa21tT88E1fv+ZvAw8AlSElKVzm88HCoE3g0Nab5hZS9K4zc65L4Dnga3ADuCgc24yadzmEIlsY9U+zrky4CBwZl0HT8ek36DVP5sCM/sS8G/gd865Q3VtGqHM1VHuO2Z2G7DbObc41l0ilDWpNhPooV0BvOacuxwoInDaH02Tb3NwHPt2AsMYXwFamtnP6tolQlmTanMMGtPGBrc/HZO+Z6t/poKZnUAg4b/rAstTA+wys3ODn58L7A6WR2t7QfB1eLkffRv4oZltBkYC15vZO6R3mwuAAufcguD70QS+BNK5zd8HPnfOFTrnSgksu/4t0rvNlRLZxqp9zKw5cCqwr66Dp2PS92z1z2QLXqEfAqx1zr0Y8tFY4BfB178gMNZfWX5v8Ip+e6AjsDB4CnnYzK4O1vlfIfv4inOur3OujXMuh8C/3XTn3M9I7zbvBLaZWadg0Q3AGtK4zQSGda42s5ODsd4ArCW921wpkW0MresuAr8vdZ/ppPoih0cXTm4lMNNlI/BYquOJox3fIX
CqtgJYFvxzK4Exu2nAhuDfZ4Ts81iw3XmEzGIAugCrgp+9Sj0Xe/zwh8DKrpUXctO6zcBlQG7w33oMcHoGtPlJYF0w3rcJzFpJqzYDIwhcsygl0Cvvmcg2AicC7wP5BGb4nF9fTLojV0Qkg6Tj8I6IiEShpC8ikkGU9EVEMoiSvohIBlHSFxHJIEr6IiIZRElfRCSDKOmLiGSQ/wMcatbLM10ZaQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(mean_works)\n",
"plt.yscale('log')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fbdd1d72",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 19,
"id": "20499dac",
"metadata": {},
"outputs": [],
"source": [
"optimized_params = get_params(opt_state)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "5b558c08",
"metadata": {},
"outputs": [],
"source": [
"inits, finals, works, logdetJs = wrapper_fn(optimized_params, random.PRNGKey(433))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "be0b59bf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([ 1., 0., 0., 0., 1., 4., 17., 29., 38., 10.]),\n",
" array([-1.37765787, -1.19513886, -1.01261984, -0.83010083, -0.64758181,\n",
" -0.4650628 , -0.28254379, -0.10002477, 0.08249424, 0.26501326,\n",
" 0.44753227]),\n",
" <BarContainer object of 10 artists>)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXUAAAD4CAYAAAATpHZ6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAPXklEQVR4nO3df6zddX3H8edrBaJTMst6wTul3s0QIzOzmJuOiX+gqIGaDPjDRLKwJiOpJiORxP3RuWTD+A8uosmSjaQIsTPOhUQcDeC0azTEqLgLKaVNUcRUB3Tt9Sfwjwvw3h/n2+Wu3B/f8+ve9uPzkZyc7/n+uN8X3/Pl1XO+53u+J1WFJKkNv7XRASRJk2OpS1JDLHVJaoilLkkNsdQlqSHnrOfKtmzZUnNzc+u5Skk66z3yyCM/raqZPvOua6nPzc2xsLCwnquUpLNekh/3ndfDL5LUEEtdkhpiqUtSQyx1SWqIpS5JDbHUJakhlrokNcRSl6SGWOqS1JB1/UapJAHM7X5gQ9Z77LYPbMh615Ov1CWpIZa6JDXEUpekhljqktQQS12SGmKpS1JDLHVJaoilLkkNsdQlqSGWuiQ1xFKXpIZY6pLUkDVLPcmrknwvyWNJjiT5RDf+1iTPJDnY3XZMP64kaTV9rtL4a+A9VfVCknOBbyX5ajfts1X16enFkyQNY81Sr6oCXugentvdapqhJEmj6XVMPcmmJAeBk8D+qnq4m3RzkkNJ7k6yeYVldyVZSLKwuLg4mdSSpGX1KvWqeqmqtgFvBLYneRtwB/BmYBtwHLh9hWX3VNV8Vc3PzMxMJLQkaXlDnf1SVb8EvglcXVUnurJ/GbgT2D75eJKkYfQ5+2Umyeu64VcD7wWeSDK7ZLbrgcNTSShJ6q3P2S+zwN4kmxj8I3BPVd2f5AtJtjH40PQY8OGppZQk9dLn7JdDwGXLjL9xKokkSSPzG6WS1BBLXZIaYqlLUkMsdUlqiKUuSQ2x1CWpIZa6JDXEUpekhljqktQQS12SGmKpS1JDLHVJaoilLkkNsdQlqSGWuiQ1xFKXpIZY6pLUEEtdkhrS54enX5Xke0keS3IkySe68Rck2Z/kye5+8/TjSpJW0+eV+q+B91TV24FtwNVJLgd2Aweq6hLgQPdYkrSB1iz1Gnihe3hudyvgWmBvN34vcN00AkqS+ut1TD3JpiQHgZPA/qp6GLioqo4DdPcXrrDsriQLSRYWFxcnFFuStJxepV5VL1XVNuCNwPYkb+u7gqraU1XzVTU/MzMzYkxJUh9Dnf1SVb8EvglcDZxIMgvQ3Z+cdDhJ0nD6nP0yk+R13fCrgfcCTwD7gJ3dbDuB+6aUUZLU0zk95pkF9ibZxOAfgXuq6v4k3wHuSXIT8BPgg1PMKUnqYc1Sr6pDwGXLjP8ZcNU0QkmSRtPnlbqkBs3tfmCjI2gKvEyAJDXEUpekhljqktQQS12SGmKpS1JDLHVJaoilLkkNsdQlqSGWuiQ1xFKXpIZY6pLUEEtdkhpiqUtSQyx1SWqIpS5JDbHUJakhlrokNaTPD09fnOQbSY4mOZLko934W5M8k+Rgd9sx/biSpNX0+Tm7F4GPVdWjSc4HHkmyv5v22ar69PTiSZKG0eeHp48Dx7vh55McBd4w7WCSpOENdUw9yRxwGfBwN+rmJIeS3J1k86TDSZKG07vUk7wW+DJwS1U9B9wBvBnYxuCV/O0rLLcryUKShcXFxfETS5JW1KvUk5zLoNC/WFX3AlTViap6qapeBu4Eti+3bFXtqar5qpqfmZmZVG5J0jL6nP0S4C7gaFV9Zsn42SWzXQ8cnnw8SdIw+pz9cgVwI/B4koPduI8DNyTZBhRwDPjwFPJJkobQ5+yXbwFZZtKDk48jSRqH3yiVpIZY6pLUEEtdkhpiqUtSQyx1SWqIpS5JDbHUJakhlrokNcRSl6SGWOqS1BBLXZIaYqlLUkMsdUlqiKUuSQ2x1CWpIZa6JD
XEUpekhljqktQQS12SGrJmqSe5OMk3khxNciTJR7vxFyTZn+TJ7n7z9ONKklbT55X6i8DHquqtwOXAXya5FNgNHKiqS4AD3WNJ0gZas9Sr6nhVPdoNPw8cBd4AXAvs7WbbC1w3pYySpJ6GOqaeZA64DHgYuKiqjsOg+IELV1hmV5KFJAuLi4tjxpUkraZ3qSd5LfBl4Jaqeq7vclW1p6rmq2p+ZmZmlIySpJ56lXqScxkU+her6t5u9Ikks930WeDkdCJKkvrqc/ZLgLuAo1X1mSWT9gE7u+GdwH2TjydJGsY5Pea5ArgReDzJwW7cx4HbgHuS3AT8BPjgVBJKknpbs9Sr6ltAVph81WTjSJLG4TdKJakhlrokNcRSl6SGWOqS1BBLXZIaYqlLUkMsdUlqiKUuSQ2x1CWpIZa6JDWkz7VfJE3R3O4HNjqCGuIrdUlqiKUuSQ2x1CWpIZa6JDXEUpekhljqktQQS12SGtLnh6fvTnIyyeEl425N8kySg91tx3RjSpL66PNK/fPA1cuM/2xVbetuD042liRpFGuWelU9BPx8HbJIksY0zjH1m5Mc6g7PbF5ppiS7kiwkWVhcXBxjdZKktYxa6ncAbwa2AceB21easar2VNV8Vc3PzMyMuDpJUh8jlXpVnaiql6rqZeBOYPtkY0mSRjFSqSeZXfLweuDwSvNKktbPmpfeTfIl4EpgS5Kngb8DrkyyDSjgGPDh6UWUJPW1ZqlX1Q3LjL5rClkkSWPyG6WS1BBLXZIaYqlLUkMsdUlqiKUuSQ2x1CWpIZa6JDXEUpekhljqktQQS12SGmKpS1JDLHVJaoilLkkNsdQlqSGWuiQ1xFKXpIZY6pLUEEtdkhqyZqknuTvJySSHl4y7IMn+JE9295unG1OS1EefV+qfB64+bdxu4EBVXQIc6B5LkjbYmqVeVQ8BPz9t9LXA3m54L3DdZGNJkkZxzojLXVRVxwGq6niSC1eaMckuYBfA1q1bR1ydJI1vbvcDG7buY7d9YF3WM/UPSqtqT1XNV9X8zMzMtFcnSb/RRi31E0lmAbr7k5OLJEka1ailvg/Y2Q3vBO6bTBxJ0jj6nNL4JeA7wFuSPJ3kJuA24H1JngTe1z2WJG2wNT8oraobVph01YSzSJLG5DdKJakhlrokNcRSl6SGWOqS1BBLXZIaYqlLUkMsdUlqiKUuSQ2x1CWpIZa6JDXEUpekhljqktQQS12SGmKpS1JDLHVJaoilLkkNsdQlqSGWuiQ1ZM2fs1tNkmPA88BLwItVNT+JUJKk0YxV6p13V9VPJ/B3JElj8vCLJDVk3FIv4OtJHkmyaxKBJEmjG/fwyxVV9WySC4H9SZ6oqoeWztCV/S6ArVu3jrk6SdJqxnqlXlXPdvcnga8A25eZZ09VzVfV/MzMzDirkyStYeRST/KaJOefGgbeDxyeVDBJ0vDGOfxyEfCVJKf+zr9U1b9PJJUkaSQjl3pV/Qh4+wSzSJLGNInz1KWz3tzuBzY6gjQRnqcuSQ2x1CWpIZa6JDXEUpekhljqktQQS12SGmKpS1JDLHVJaoilLkkNsdQlqSGWuiQ1xFKXpIZY6pLUEEtdkhpiqUtSQ86a66lv5PWuj932gQ1b928ar2sujcdX6pLUEEtdkhoyVqknuTrJ95P8MMnuSYWSJI1m5FJPsgn4R+Aa4FLghiSXTiqYJGl447xS3w78sKp+VFX/A/wrcO1kYkmSRjHO2S9vAP5ryeOngT8+faYku4Bd3cMXknx/jHVO2hbgp2vNlE+tQ5LV9cp5BjDn5JwNGcGcvfXokdUyvqnvesYp9Swzrl4xomoPsGeM9UxNkoWqmt/oHGsx52SdDTnPhoxgzkmaVMZxDr88DVy85PEbgWfHiyNJGsc4pf6fwCVJfj/JecCHgH2TiSVJGsXIh1+q6sUkNwNfAzYBd1fVkYklWx9n5GGhZZhzss6GnGdDRjDnJE0kY6pecRhcknSW8hulktQQS12SGt
J8qSf5YJIjSV5OsuzpQkkuTvKNJEe7eT+6ZNqtSZ5JcrC77dionN18y16aIckFSfYnebK73zylnGuuJ8lblmyvg0meS3JLN23q27PvtkhyLMnjXY6FYZdfj5wbtW+udQmQDPxDN/1Qknf0XXaSeuT8sy7foSTfTvL2JdOWff43KOeVSX615Ln8277LvkJVNX0D3gq8BfgmML/CPLPAO7rh84EfAJd2j28F/uoMybkJeAr4A+A84LElOf8e2N0N7wY+NaWcQ62ny/zfwJvWa3v2zQgcA7aM+984zZwbsW+utp8tmWcH8FUG31e5HHi477LrnPOdwOZu+JpTOVd7/jco55XA/aMse/qt+VfqVXW0qlb9FmtVHa+qR7vh54GjDL4xu2765GT1SzNcC+zthvcC100l6PDruQp4qqp+PKU8yxl3W5wx23KD9s0+lwC5FvjnGvgu8Loksz2XXbecVfXtqvpF9/C7DL5Ps97G2SZDL9t8qQ8ryRxwGfDwktE3d2/f7p7WW/Gelrs0w6n/wS+qquMwKALgwillGHY9HwK+dNq4aW/PvhkL+HqSRzK4nMWwy69XTmBd983V9rO15umz7KQMu66bGLy7OGWl53/S+ub8kySPJflqkj8cctn/c9b88tFqkvwH8PplJv1NVd03xN95LfBl4Jaqeq4bfQfwSQY7wCeB24G/2KCcvS7NMK7Vcg75d84D/hT46yWjJ7I9J5Txiqp6NsmFwP4kT1TVQ8NmWc0Et+VU983TV7fMuNP3s5XmWZd9dI0Mr5wxeTeDUn/XktFTf/6HyPkog0OUL3SfjfwbcEnPZf+fJkq9qt477t9Ici6D/2m+WFX3LvnbJ5bMcydw/6jrmEDO1S7NcCLJbFUd794Gnxx1JavlTDLMeq4BHl26DSe1PSeRsaqe7e5PJvkKg7e6D3GGbcv12DdP0+cSICvNc16PZSel16VKkvwR8Dngmqr62anxqzz/655zyT/UVNWDSf4pyZY+y57Owy8MPskH7gKOVtVnTps2u+Th9cDh9cx2mtUuzbAP2NkN7wR6v0MZ0jDruYHTDr2s0/ZcM2OS1yQ5/9Qw8P4lWc6YbblB+2afS4DsA/68OwvmcuBX3SGk9bx8yJrrSrIVuBe4sap+sGT8as//RuR8ffdck2Q7g27+WZ9lX2Han/xu9I3Bzv408GvgBPC1bvzvAQ92w+9i8JbmEHCwu+3opn0BeLybtg+Y3aic3eMdDM6AeIrBYZtT438XOAA82d1fMKWcy65nmZy/3e2Uv3Pa8lPfnn0yMjib4LHuduRM3ZYbtW8ut58BHwE+0g2HwY/kPNVlmF9t2WndeuT8HPCLJdtuYa3nf4Ny3tzleIzBB7rvHHV7epkASWqIh18kqSGWuiQ1xFKXpIZY6pLUEEtdkhpiqUtSQyx1SWrI/wJwsQN2EguW7gAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.hist(works)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "e3d17225",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([ 1., 0., 0., 0., 1., 4., 17., 29., 38., 10.]),\n",
" array([-1.37765787, -1.19513886, -1.01261984, -0.83010083, -0.64758181,\n",
" -0.4650628 , -0.28254379, -0.10002477, 0.08249424, 0.26501326,\n",
" 0.44753227]),\n",
" <BarContainer object of 10 artists>)"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXUAAAD4CAYAAAATpHZ6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAPXklEQVR4nO3df6zddX3H8edrBaJTMst6wTul3s0QIzOzmJuOiX+gqIGaDPjDRLKwJiOpJiORxP3RuWTD+A8uosmSjaQIsTPOhUQcDeC0azTEqLgLKaVNUcRUB3Tt9Sfwjwvw3h/n2+Wu3B/f8+ve9uPzkZyc7/n+uN8X3/Pl1XO+53u+J1WFJKkNv7XRASRJk2OpS1JDLHVJaoilLkkNsdQlqSHnrOfKtmzZUnNzc+u5Skk66z3yyCM/raqZPvOua6nPzc2xsLCwnquUpLNekh/3ndfDL5LUEEtdkhpiqUtSQyx1SWqIpS5JDbHUJakhlrokNcRSl6SGWOqS1JB1/UapJAHM7X5gQ9Z77LYPbMh615Ov1CWpIZa6JDXEUpekhljqktQQS12SGmKpS1JDLHVJaoilLkkNsdQlqSGWuiQ1xFKXpIZY6pLUkDVLPcmrknwvyWNJjiT5RDf+1iTPJDnY3XZMP64kaTV9rtL4a+A9VfVCknOBbyX5ajfts1X16enFkyQNY81Sr6oCXugentvdapqhJEmj6XVMPcmmJAeBk8D+qnq4m3RzkkNJ7k6yeYVldyVZSLKwuLg4mdSSpGX1KvWqeqmqtgFvBLYneRtwB/BmYBtwHLh9hWX3VNV8Vc3PzMxMJLQkaXlDnf1SVb8EvglcXVUnurJ/GbgT2D75eJKkYfQ5+2Umyeu64VcD7wWeSDK7ZLbrgcNTSShJ6q3P2S+zwN4kmxj8I3BPVd2f5AtJtjH40PQY8OGppZQk9dLn7JdDwGXLjL9xKokkSSPzG6WS1BBLXZIaYqlLUkMsdUlqiKUuSQ2x1CWpIZa6JDXEUpekhljqktQQS12SGmKpS1JDLHVJaoilLkkNsdQlqSGWuiQ1xFKXpIZY6pLUEEtdkhrS54enX5Xke0keS3IkySe68Rck2Z/kye5+8/TjSpJW0+eV+q+B91TV24FtwNVJLgd2Aweq6hLgQPdYkrSB1iz1Gnihe3hudyvgWmBvN34vcN00AkqS+ut1TD3JpiQHgZPA/qp6GLioqo4DdPcXrrDsriQLSRYWFxcnFFuStJxepV5VL1XVNuCNwPYkb+u7gqraU1XzVTU/MzMzYkxJUh9Dnf1SVb8EvglcDZxIMgvQ3Z+cdDhJ0nD6nP0yk+R13fCrgfcCTwD7gJ3dbDuB+6aUUZLU0zk95pkF9ibZxOAfgXuq6v4k3wHuSXIT8BPgg1PMKUnqYc1Sr6pDwGXLjP8ZcNU0QkmSRtPnlbqkBs3tfmCjI2gKvEyAJDXEUpekhljqktQQS12SGmKpS1JDLHVJaoilLkkNsdQlqSGWuiQ1xFKXpIZY6pLUEEtdkhpiqUtSQyx1SWqIpS5JDbHUJakhlrokNaTPD09fnOQbSY4mOZLko934W5M8k+Rgd9sx/biSpNX0+Tm7F4GPVdWjSc4HHkmyv5v22ar69PTiSZKG0eeHp48Dx7vh55McBd4w7WCSpOENdUw9yRxwGfBwN+rmJIeS3J1k86TDSZKG07vUk7wW+DJwS1U9B9wBvBnYxuCV/O0rLLcryUKShcXFxfETS5JW1KvUk5zLoNC/WFX3AlTViap6qapeBu4Eti+3bFXtqar5qpqfmZmZVG5J0jL6nP0S4C7gaFV9Zsn42SWzXQ8cnnw8SdIw+pz9cgVwI/B4koPduI8DNyTZBhRwDPjwFPJJkobQ5+yXbwFZZtKDk48jSRqH3yiVpIZY6pLUEEtdkhpiqUtSQyx1SWqIpS5JDbHUJakhlrokNcRSl6SGWOqS1BBLXZIaYqlLUkMsdUlqiKUuSQ2x1CWpIZa6JD
XEUpekhljqktQQS12SGrJmqSe5OMk3khxNciTJR7vxFyTZn+TJ7n7z9ONKklbT55X6i8DHquqtwOXAXya5FNgNHKiqS4AD3WNJ0gZas9Sr6nhVPdoNPw8cBd4AXAvs7WbbC1w3pYySpJ6GOqaeZA64DHgYuKiqjsOg+IELV1hmV5KFJAuLi4tjxpUkraZ3qSd5LfBl4Jaqeq7vclW1p6rmq2p+ZmZmlIySpJ56lXqScxkU+her6t5u9Ikks930WeDkdCJKkvrqc/ZLgLuAo1X1mSWT9gE7u+GdwH2TjydJGsY5Pea5ArgReDzJwW7cx4HbgHuS3AT8BPjgVBJKknpbs9Sr6ltAVph81WTjSJLG4TdKJakhlrokNcRSl6SGWOqS1BBLXZIaYqlLUkMsdUlqiKUuSQ2x1CWpIZa6JDWkz7VfJE3R3O4HNjqCGuIrdUlqiKUuSQ2x1CWpIZa6JDXEUpekhljqktQQS12SGtLnh6fvTnIyyeEl425N8kySg91tx3RjSpL66PNK/fPA1cuM/2xVbetuD042liRpFGuWelU9BPx8HbJIksY0zjH1m5Mc6g7PbF5ppiS7kiwkWVhcXBxjdZKktYxa6ncAbwa2AceB21easar2VNV8Vc3PzMyMuDpJUh8jlXpVnaiql6rqZeBOYPtkY0mSRjFSqSeZXfLweuDwSvNKktbPmpfeTfIl4EpgS5Kngb8DrkyyDSjgGPDh6UWUJPW1ZqlX1Q3LjL5rClkkSWPyG6WS1BBLXZIaYqlLUkMsdUlqiKUuSQ2x1CWpIZa6JDXEUpekhljqktQQS12SGmKpS1JDLHVJaoilLkkNsdQlqSGWuiQ1xFKXpIZY6pLUEEtdkhqyZqknuTvJySSHl4y7IMn+JE9295unG1OS1EefV+qfB64+bdxu4EBVXQIc6B5LkjbYmqVeVQ8BPz9t9LXA3m54L3DdZGNJkkZxzojLXVRVxwGq6niSC1eaMckuYBfA1q1bR1ydJI1vbvcDG7buY7d9YF3WM/UPSqtqT1XNV9X8zMzMtFcnSb/RRi31E0lmAbr7k5OLJEka1ailvg/Y2Q3vBO6bTBxJ0jj6nNL4JeA7wFuSPJ3kJuA24H1JngTe1z2WJG2wNT8oraobVph01YSzSJLG5DdKJakhlrokNcRSl6SGWOqS1BBLXZIaYqlLUkMsdUlqiKUuSQ2x1CWpIZa6JDXEUpekhljqktQQS12SGmKpS1JDLHVJaoilLkkNsdQlqSGWuiQ1ZM2fs1tNkmPA88BLwItVNT+JUJKk0YxV6p13V9VPJ/B3JElj8vCLJDVk3FIv4OtJHkmyaxKBJEmjG/fwyxVV9WySC4H9SZ6oqoeWztCV/S6ArVu3jrk6SdJqxnqlXlXPdvcnga8A25eZZ09VzVfV/MzMzDirkyStYeRST/KaJOefGgbeDxyeVDBJ0vDGOfxyEfCVJKf+zr9U1b9PJJUkaSQjl3pV/Qh4+wSzSJLGNInz1KWz3tzuBzY6gjQRnqcuSQ2x1CWpIZa6JDXEUpekhljqktQQS12SGmKpS1JDLHVJaoilLkkNsdQlqSGWuiQ1xFKXpIZY6pLUEEtdkhpiqUtSQ86a66lv5PWuj932gQ1b928ar2sujcdX6pLUEEtdkhoyVqknuTrJ95P8MMnuSYWSJI1m5FJPsgn4R+Aa4FLghiSXTiqYJGl447xS3w78sKp+VFX/A/wrcO1kYkmSRjHO2S9vAP5ryeOngT8+faYku4Bd3cMXknx/jHVO2hbgp2vNlE+tQ5LV9cp5BjDn5JwNGcGcvfXokdUyvqnvesYp9Swzrl4xomoPsGeM9UxNkoWqmt/oHGsx52SdDTnPhoxgzkmaVMZxDr88DVy85PEbgWfHiyNJGsc4pf6fwCVJfj/JecCHgH2TiSVJGsXIh1+q6sUkNwNfAzYBd1fVkYklWx9n5GGhZZhzss6GnGdDRjDnJE0kY6pecRhcknSW8hulktQQS12SGt
J8qSf5YJIjSV5OsuzpQkkuTvKNJEe7eT+6ZNqtSZ5JcrC77dionN18y16aIckFSfYnebK73zylnGuuJ8lblmyvg0meS3JLN23q27PvtkhyLMnjXY6FYZdfj5wbtW+udQmQDPxDN/1Qknf0XXaSeuT8sy7foSTfTvL2JdOWff43KOeVSX615Ln8277LvkJVNX0D3gq8BfgmML/CPLPAO7rh84EfAJd2j28F/uoMybkJeAr4A+A84LElOf8e2N0N7wY+NaWcQ62ny/zfwJvWa3v2zQgcA7aM+984zZwbsW+utp8tmWcH8FUG31e5HHi477LrnPOdwOZu+JpTOVd7/jco55XA/aMse/qt+VfqVXW0qlb9FmtVHa+qR7vh54GjDL4xu2765GT1SzNcC+zthvcC100l6PDruQp4qqp+PKU8yxl3W5wx23KD9s0+lwC5FvjnGvgu8Loksz2XXbecVfXtqvpF9/C7DL5Ps97G2SZDL9t8qQ8ryRxwGfDwktE3d2/f7p7WW/Gelrs0w6n/wS+qquMwKALgwillGHY9HwK+dNq4aW/PvhkL+HqSRzK4nMWwy69XTmBd983V9rO15umz7KQMu66bGLy7OGWl53/S+ub8kySPJflqkj8cctn/c9b88tFqkvwH8PplJv1NVd03xN95LfBl4Jaqeq4bfQfwSQY7wCeB24G/2KCcvS7NMK7Vcg75d84D/hT46yWjJ7I9J5Txiqp6NsmFwP4kT1TVQ8NmWc0Et+VU983TV7fMuNP3s5XmWZd9dI0Mr5wxeTeDUn/XktFTf/6HyPkog0OUL3SfjfwbcEnPZf+fJkq9qt477t9Ici6D/2m+WFX3LvnbJ5bMcydw/6jrmEDO1S7NcCLJbFUd794Gnxx1JavlTDLMeq4BHl26DSe1PSeRsaqe7e5PJvkKg7e6D3GGbcv12DdP0+cSICvNc16PZSel16VKkvwR8Dngmqr62anxqzz/655zyT/UVNWDSf4pyZY+y57Owy8MPskH7gKOVtVnTps2u+Th9cDh9cx2mtUuzbAP2NkN7wR6v0MZ0jDruYHTDr2s0/ZcM2OS1yQ5/9Qw8P4lWc6YbblB+2afS4DsA/68OwvmcuBX3SGk9bx8yJrrSrIVuBe4sap+sGT8as//RuR8ffdck2Q7g27+WZ9lX2Han/xu9I3Bzv408GvgBPC1bvzvAQ92w+9i8JbmEHCwu+3opn0BeLybtg+Y3aic3eMdDM6AeIrBYZtT438XOAA82d1fMKWcy65nmZy/3e2Uv3Pa8lPfnn0yMjib4LHuduRM3ZYbtW8ut58BHwE+0g2HwY/kPNVlmF9t2WndeuT8HPCLJdtuYa3nf4Ny3tzleIzBB7rvHHV7epkASWqIh18kqSGWuiQ1xFKXpIZY6pLUEEtdkhpiqUtSQyx1SWrI/wJwsQN2EguW7gAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.hist(works)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "7af5751b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([ 1., 1., 5., 13., 12., 30., 15., 11., 7., 5.]),\n",
" array([-0.43851942, -0.36744703, -0.29637464, -0.22530225, -0.15422986,\n",
" -0.08315747, -0.01208508, 0.05898731, 0.1300597 , 0.20113209,\n",
" 0.27220448]),\n",
" <BarContainer object of 10 artists>)"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXUAAAD4CAYAAAATpHZ6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAOQElEQVR4nO3dfahk9X3H8fcnPmCIBrVezVbd3sZIWwnJKrdWahBTk+JDqfpHoNKahQqbQAIJTWi3CbQJpWBKY0ohhG6iZEuNRVBR1D7YbYKEGJOrWXVlTYzBpiaLu0at+k9a9ds/5thernd3zp3Huz/fLxjmnDO/mfNx1vvZc8+c+W2qCklSG9407wCSpMmx1CWpIZa6JDXEUpekhljqktSQI2e5s5NOOqkWFxdnuUtJOuw98MADz1TVQp+xMy31xcVFlpeXZ7lLSTrsJfmPvmM9/SJJDbHUJakhlrokNcRSl6SGWOqS1BBLXZIaMrTUkxyT5DtJHkryaJLPdttPTHJPkse7+xOmH1eSdCh9jtR/DvxWVb0b2AJcnOQ8YDuwq6rOBHZ165KkORpa6jXwUrd6VHcr4HJgZ7d9J3DFNAJKkvrr9Y3SJEcADwDvAL5YVfcnOaWq9gFU1b4kJx/kuduAbQCbN2+eTGppwha33zW3fT957WVz27fa0+uD0qp6paq2AKcB5yZ5Z98dVNWOqlqqqqWFhV5TF0iSRrSuq1+q6nngG8DFwNNJNgF09/snHU6StD59rn5ZSHJ8t/xm4H3AY8AdwNZu2Fbg9illlCT11Oec+iZgZ3de/U3AzVV1Z5L7gJuTXAP8GPjAFHNKknoYWupV9TBw9hrbfwZcNI1QkqTR+I1SSWqIpS5JDbHUJakhlrokNcRSl6SGWOqS1BBLXZIaYqlLUkMsdUlqiKUuSQ2x1CWpIZa6JDXEUpekhljqktQQS12SGmKpS1JDLHVJaoilLkkNsdQlqSGWuiQ1xFKXpIZY6pLUEEtdkhpiqUtSQyx1SWrI0FJPcnqSryfZm+TRJB/rtn8myU+S7O5ul04/riTpUI7sMeZl4BNV9WCS44AHktzTPfaFqvrr6cWTJK3H0FKvqn3Avm75xSR7gVOnHUyStH7rOqeeZBE4G7i/2/TRJA8nuSHJCQd5zrYky0mWDxw4MF5aSdIh9S71JMcCtwAfr6oXgC8BZwBbGBzJf36t51XVjqpaqqqlhYWF8RNLkg6qV6knOYpBod9YVbcCVNXTVfVKVb0KfBk4d3oxJUl99Ln6JcD1wN6qum7F9k0rhl0J7Jl8PEnSevS5+uV84GrgkSS7u22fAq5KsgUo4EngQ1PIJ0lahz5Xv3wTyBoP3T35OJKkcfiNUklqiKUuSQ2x1CWpIZa6JDXEUpekhljqktQQS12SGmKpS1JDLHVJaoilLkkNsdQlqSGWuiQ1xFKXpIZY6pLUEEtdkhpiqUtSQyx1SWqIpS5JDbHUJakhlrokNcRSl6SGWOqS1BBLXZIaYqlLUkMsdUlqiKUuSQ0ZWupJTk/y9SR7kzya5GPd9hOT3JPk8e7+hOnHlSQdSp8j9ZeBT1TVrwHnAR9JchawHdhVVWcCu7p1SdIcDS31qtpXVQ92yy8Ce4FTgcuBnd2wncAVU8ooSeppXefUkywCZwP3A6dU1T4YFD9w8kGesy3JcpLlAwcOjBlXknQovUs9ybHALcDHq+qFvs+rqh1VtVRVSwsLC6NklCT11KvUkxzFoNBvrKpbu81PJ9nUPb4J2D+diJKkvvpc/RLgemBvVV234qE7gK3d8lbg9snHkyStx5E9xpwPXA08kmR3t+1TwLXAzUmuAX4MfGAqCSVJvQ0t9ar6JpCDPHzRZONIksbhN0olqSGWuiQ1xFKXpIZY6pLUEEtdkhpiqUtSQyx1SWqIpS5JDbHUJakhlrokNcRSl6SGWOqS1BBLXZIaYqlLUkMsdUlqiKUuSQ2x1CWpIZa6JDXEUpekhljqkt
QQS12SGmKpS1JDLHVJaoilLkkNsdQlqSFDSz3JDUn2J9mzYttnkvwkye7udul0Y0qS+uhzpP5V4OI1tn+hqrZ0t7snG0uSNIqhpV5V9wLPziCLJGlMR47x3I8m+SCwDHyiqp5ba1CSbcA2gM2bN4+xO6lNi9vvmst+n7z2srnsV9M16gelXwLOALYA+4DPH2xgVe2oqqWqWlpYWBhxd5KkPkYq9ap6uqpeqapXgS8D5042liRpFCOVepJNK1avBPYcbKwkaXaGnlNPchNwIXBSkqeAPwcuTLIFKOBJ4EPTiyhJ6mtoqVfVVWtsvn4KWSRJY/IbpZLUEEtdkhpiqUtSQyx1SWqIpS5JDbHUJakhlrokNcRSl6SGjDNLoxo1r1kDwZkDpXF5pC5JDbHUJakhlrokNcRSl6SGWOqS1BBLXZIa4iWN2lDmeTml1AKP1CWpIZa6JDXEUpekhljqktQQS12SGmKpS1JDvKRReoNyNs42eaQuSQ2x1CWpIZa6JDVkaKknuSHJ/iR7Vmw7Mck9SR7v7k+YbkxJUh99jtS/Cly8att2YFdVnQns6tYlSXM2tNSr6l7g2VWbLwd2dss7gSsmG0uSNIpRz6mfUlX7ALr7kw82MMm2JMtJlg8cODDi7iRJfUz9g9Kq2lFVS1W1tLCwMO3dSdIb2qil/nSSTQDd/f7JRZIkjWrUUr8D2NotbwVun0wcSdI4+lzSeBNwH/ArSZ5Kcg1wLfD+JI8D7+/WJUlzNnTul6q66iAPXTThLJKkMfmNUklqiKUuSQ2x1CWpIZa6JDXEUpekhljqktQQS12SGmKpS1JDLHVJaoilLkkNsdQlqSGWuiQ1xFKXpIZY6pLUkKFT70rSpC1uv2su+33y2svmst9Z8khdkhpiqUtSQyx1SWqIpS5JDbHUJakhlrokNcRSl6SGWOqS1BBLXZIaYqlLUkPGmiYgyZPAi8ArwMtVtTSJUJKk0Uxi7pf3VtUzE3gdSdKYPP0iSQ0Z90i9gH9NUsDfVdWO1QOSbAO2AWzevHnM3b2xzGsmO6lV8/yZmtUMkeMeqZ9fVecAlwAfSXLB6gFVtaOqlqpqaWFhYczdSZIOZaxSr6qfdvf7gduAcycRSpI0mpFLPclbkhz32jLw28CeSQWTJK3fOOfUTwFuS/La63ytqv55IqkkSSMZudSr6kfAuyeYRZI0Ji9plKSGWOqS1BBLXZIaYqlLUkMsdUlqiKUuSQ2x1CWpIZa6JDXEUpekhljqktQQS12SGmKpS1JDLHVJaoilLkkNsdQlqSGWuiQ1xFKXpIZY6pLUEEtdkhpiqUtSQyx1SWqIpS5JDTly3gH6Wtx+17wjSNKG55G6JDXEUpekhljqktSQsUo9ycVJvp/kh0m2TyqUJGk0I5d6kiOALwKXAGcBVyU5a1LBJEnrN86R+rnAD6vqR1X138A/ApdPJpYkaRTjXNJ4KvCfK9afAn5j9aAk24Bt3epLSb4/xj7HdRLwzBz339fhkhPMOi1mnY65Zc3n1jV8dc5f6vvEcUo9a2yr122o2gHsGGM/E5NkuaqW5p1jmMMlJ5h1Wsw6HYdL1nFyjnP65Sng9BXrpwE/HeP1JEljGqfUvwucmeSXkxwN/B5wx2RiSZJGMfLpl6p6OclHgX8BjgBuqKpHJ5ZsOjbEaaAeDpecYNZpMet0HC5ZR86ZqtedBpckHab8RqkkNcRSl6SGNF3qSU5Mck+Sx7v7Ew4x9ogk30ty5ywzdvsemjPJMUm+k+ShJI8m+eysc64j6+lJvp5kb5f1Yxs1azfuhiT7k+yZQ8ZDTrWRgb/tHn84yTmzztgz568muS/Jz5N8ch4ZV2QZlvX3u/fy4STfSvLueeTssgzLenmXc3eS5STvGfqiVdXsDfgrYHu3vB343CHG/hHwNeDOjZiTwfcCju2WjwLuB87boFk3Aed0y8cBPwDO2ohZu8cuAM4B9sw43xHAE8DbgaOBh1a/T8ClwD91f/7nAf
fP4X3sk/Nk4NeBvwQ+OeuM68z6m8AJ3fIl83hP15H1WP7/s893AY8Ne92mj9QZTFuws1veCVyx1qAkpwGXAV+ZTazXGZqzBl7qVo/qbvP4lLtP1n1V9WC3/CKwl8E3kGet159/Vd0LPDujTCv1mWrjcuDvuz//bwPHJ9m00XJW1f6q+i7wPzPOtlqfrN+qque61W8z+I7NPPTJ+lJ1jQ68hR4/862X+ilVtQ8GRcPgaGItfwP8MfDqjHKt1itnd4poN7AfuKeq7p9dxP/T9z0FIMkicDaD3yxmbV1Z52CtqTZW/+XXZ8y0bYQMfa036zUMfhOah15Zk1yZ5DHgLuAPh73oYfPP2R1Mkn8D3rbGQ5/u+fzfAfZX1QNJLpxgtNX7GSsnQFW9AmxJcjxwW5J3VtXEzwNPImv3OscCtwAfr6oXJpFtjX1MJOuc9Jlqo9d0HFO2ETL01TtrkvcyKPXh56mno+9UK7cx+Hm/APgL4H2HetHDvtSr6qD/gUmeTrKpqvZ1v7LuX2PY+cDvJrkUOAZ4a5J/qKo/2GA5V77W80m+AVwMTLzUJ5E1yVEMCv3Gqrp10hlfM8n3dQ76TLWxEabj2AgZ+uqVNcm7GJxuvaSqfjajbKut632tqnuTnJHkpKo66KRkrZ9+uQPY2i1vBW5fPaCq/rSqTquqRQZTHfz7pAu9h6E5kyx0R+gkeTODv60fm1XAFfpkDXA9sLeqrpthttWGZp2zPlNt3AF8sLsK5jzgv147pbTBcm4UQ7Mm2QzcClxdVT+YQ8bX9Mn6ju7nie7Kp6OBQ/8lNI9PfWd1A34B2AU83t2f2G3/ReDuNcZfyHyufhmak8En398DHmZwdP5nG/U9ZfDrbHVZd3e3Szdi1m79JmAfgw/5ngKumWHGSxlcHfQE8Olu24eBD3fLYfCP0TwBPAIszenPfVjOt3Xv3QvA893yWzdo1q8Az634f3N5Hjl7Zv0T4NEu533Ae4a9ptMESFJDWj/9IklvKJa6JDXEUpekhljqktQQS12SGmKpS1JDLHVJasj/Ak2EtMGQYXUOAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.hist(logdetJs)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "f09d4628",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(-0.00336791, dtype=float64)"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"free_energy(works)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "d6552121",
"metadata": {},
"outputs": [],
"source": [
"initxs, initvs = pull_out_pv(inits)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "a7a077bb",
"metadata": {},
"outputs": [],
"source": [
"finalxs, finalvs = pull_out_pv(finals)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "e0c218ea",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([ 5., 7., 22., 38., 38., 44., 26., 15., 3., 2.]),\n",
" array([1.12576775, 1.30081435, 1.47586095, 1.65090754, 1.82595414,\n",
" 2.00100073, 2.17604733, 2.35109392, 2.52614052, 2.70118712,\n",
" 2.87623371]),\n",
" <BarContainer object of 10 artists>)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAN0klEQVR4nO3df6jd9X3H8efLNMVSHRpyk2ZVm43FMSdU5eIcwqjTFOfGkj/msLAuDCFsbGDdYMQONvqXbn/UMSiM0MruWNtNaF2CtF2zLFIGne2N06rEGlecE0Nya9eqbGzo3vvjfgNpem/O954f95z7yfMBl++P8z33vPgYX3zu95zv96SqkCRtfJdMO4AkaTwsdElqhIUuSY2w0CWpERa6JDXiXev5Ylu3bq2dO3eu50tK0oZ3/Pjx71bV3KDj1rXQd+7cyeLi4nq+pCRteEn+vc9xnnKRpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGrOuVolrdw0deHOp59+++dsxJJG1UztAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDWi11fQJXkZeBN4B3i7quaTbAH+DtgJvAz8elX952RiSpIGWcsM/baquqGq5rvtA8DRqtoFHO22JUlTMsoplz3AQre+AOwdOY0kaWh9C72AryY5nmR/t297VZ0C6JbbVnpikv1JFpMsLi0tjZ5YkrSiXufQgVur6rUk24AjSV7o+wJVdRA4CDA/P19DZJQk9dBrhl5Vr3XLM8BjwM3A6SQ7ALrlmUmFlCQNNrDQk7w3yeVn14EPA88Bh4F93WH7gEOTCilJGqzPKZftwGNJzh7/uar6SpJvAo8muRd4Bbh7cjElSYMMLPSq+g7wwRX2vw7cPolQkqS180pRSWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1Ij+n5JtCSt7tiDwz3vtgfGm+Mi5wxdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqRO9CT7Ipyb8mebzb3pLkSJKT3fLKycWUJA2ylhn6fcCJc7YPAEerahdwtNuWJE1Jr0JPchXwy8Cnz9m9B1jo1heAvWNNJklak74z9D8H/hD4v3P2ba+qUwDdcttKT0yyP8liksWlpaVRskqSLmBgoSf5FeBMVR0f5gWq6mBVzVfV/Nzc3DC/QpLUQ58vuLgV+NUkdwGXAj+W5G+A00l2VNWpJDuAM5MMKkm6sIEz9Kp6oKquqqqdwD3AP1XVbwCHgX3dYfuAQxNLKUkaaJTPoT8E7E5yEtjdbUuSpmRN3ylaVU8AT3TrrwO3jz+SJGkYXikqSY1Y0wxd0gZw7MFpJ9CUOEOXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIv4Jug3v4yItrfs79u6+dQBJJ0+YMXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRgws9CSXJvlGkmeSPJ/kE93+LUmOJDnZLa+cfFxJ0mr6zND/B/jFqvogcANwZ5JbgAPA0araBRzttiVJUzKw0GvZW93m5u6ngD3AQrd/Adg7iYCSpH56XfqfZBNwHPgp4FNV9WSS7VV1CqCqTiXZtspz9wP7Aa655prxpF4Hw1xSD15WL2l6er0pWlXvVNUNwFXAzUmu7/sCVXWwquaran5ubm7ImJKkQdb0KZeq+j7wBHAncDrJDoBueWbc4SRJ/fX5lMtckiu69fcAdwAvAIeBfd1h+4BDE8ooSe
qhzzn0HcBCdx79EuDRqno8ydeBR5PcC7wC3D3BnJKkAQYWelV9C7hxhf2vA7dPIpQkae28UlSSGmGhS1IjLHRJaoSFLkmNsNAlqRG9Lv1Xf8PeMkCSRuUMXZIaYaFLUiMsdElqhOfQJU3PsQeHf+5tD4wvRyOcoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RG+LHFi9Cwtye4f/e1Y04iaZycoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RG+LFFaRaNchdCXbScoUtSIyx0SWqEhS5JjbDQJakRAws9ydVJjiU5keT5JPd1+7ckOZLkZLe8cvJxJUmr6TNDfxv4g6r6GeAW4HeTXAccAI5W1S7gaLctSZqSgYVeVaeq6qlu/U3gBPB+YA+w0B22AOydUEZJUg9rOoeeZCdwI/AksL2qTsFy6QPbVnnO/iSLSRaXlpZGjCtJWk3vQk9yGfAF4GNV9Ubf51XVwaqar6r5ubm5YTJKknroVehJNrNc5p+tqi92u08n2dE9vgM4M5mIkqQ++nzKJcBngBNV9clzHjoM7OvW9wGHxh9PktRXn3u53Ap8FHg2ydPdvo8DDwGPJrkXeAW4eyIJJUm9DCz0qvpnIKs8fPt440iShuWVopLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqRJ+7LUrS7Dn24HDPu+2B8eaYIc7QJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJasTAQk/ySJIzSZ47Z9+WJEeSnOyWV042piRpkD4z9L8C7jxv3wHgaFXtAo5225KkKRpY6FX1NeB75+3eAyx06wvA3vHGkiSt1bDn0LdX1SmAbrlttQOT7E+ymGRxaWlpyJeTJA0y8TdFq+pgVc1X1fzc3NykX06SLlrDFvrpJDsAuuWZ8UWSJA1j2EI/DOzr1vcBh8YTR5I0rD4fW/w88HXgp5O8muRe4CFgd5KTwO5uW5I0Re8adEBVfWSVh24fcxZJ0gi8UlSSGjFwhj4rHj7y4lDPu3/3tWNOIkmzyRm6JDXCQpekRljoktSIDXMOXdPn+xhDOPbgtBPoIuIMXZIaYaFLUiMsdElqhOfQNbM8Zy+tjTN0SWqEhS5JjfCUiyZu2FMnM8OPHmqDcIYuSY2w0CWpERa6JDWi+XPoG/78rST15AxdkhphoUtSIyx0SWpE8+fQJemHjHJdwW0PjC/HBDhDl6RGWOiS1AhPuUhSXzN+usYZuiQ1wkKXpEZY6JLUCAtdkhoxUqEnuTPJt5O8lOTAuEJJktZu6EJPsgn4FPBLwHXAR5JcN65gkqS1GWWGfjPwUlV9p6r+F/hbYM94YkmS1mqUz6G/H/iPc7ZfBX7u/IOS7Af2d5tvJfn2CK85SVuB7047RA8bIedUM/5+/0Mdy/HZCDmnnPHjfQ9cKecH+jxxlELPCvvqR3ZUHQQOjvA66yLJYlXNTzvHIBsh50bICBsj50bICBsj50bICKPlHOWUy6vA1edsXwW8NsLvkySNYJRC/yawK8lPJHk3cA9weDyxJElrNfQpl6p6O8nvAf8AbAIeqarnx5Zs/c38aaHORsi5ETLCxsi5ETLCxsi5ETLCCDlT9SOnvSVJG5BXikpSIyx0SWrERVvoSbYkOZLkZLe8cpXjXk7ybJKnkyyuU7YL3lIhy/6ie/xbSW5aj1xD5PxQkh90Y/d0kj+eQsZHkpxJ8twqj099LHtknIVxvDrJsSQnkjyf5L4VjpmFseyTcxbG89Ik30jyTJfzEyscs/bxrKqL8gf4M+BAt34A+NNVjnsZ2LqOuTYB/wb8JPBu4BnguvOOuQv4MsvXAtwCPDmF8euT80PA41P+7/wLwE3Ac6s8PgtjOSjjLIzjDuCmbv1y4MUZ/XfZJ+csjGeAy7r1zcCTwC2jju
dFO0Nn+TYFC936ArB3elF+SJ9bKuwB/rqW/QtwRZIdM5hz6qrqa8D3LnDI1MeyR8apq6pTVfVUt/4mcILlq8XPNQtj2Sfn1HVj9Fa3ubn7Of8TKmsez4u50LdX1SlY/kcAbFvluAK+muR4dxuDSVvplgrn/4Psc8yk9c3w892flV9O8rPrE21NZmEs+5iZcUyyE7iR5VnluWZqLC+QE2ZgPJNsSvI0cAY4UlUjj2fT3yma5B+B963w0B+t4dfcWlWvJdkGHEnyQjejmpQ+t1TodduFCeuT4SngA1X1VpK7gL8Hdk062BrNwlgOMjPjmOQy4AvAx6rqjfMfXuEpUxnLATlnYjyr6h3ghiRXAI8lub6qzn0fZc3j2fQMvaruqKrrV/g5BJw+++dLtzyzyu94rVueAR5j+VTDJPW5pcIs3HZhYIaqeuPsn5VV9SVgc5Kt6xexl1kYywualXFMspnlkvxsVX1xhUNmYiwH5ZyV8Twnz/eBJ4A7z3tozePZdKEPcBjY163vAw6df0CS9ya5/Ow68GFgxU8ijFGfWyocBn6zexf8FuAHZ08fraOBOZO8L0m69ZtZ/vf2+jrnHGQWxvKCZmEcu9f/DHCiqj65ymFTH8s+OWdkPOe6mTlJ3gPcAbxw3mFrHs+mT7kM8BDwaJJ7gVeAuwGS/Djw6aq6C9jO8p9CsDxWn6uqr0wyVK1yS4Ukv909/pfAl1h+B/wl4L+A35pkphFy/hrwO0neBv4buKe6t+/XS5LPs/yphq1JXgX+hOU3oGZmLHtknPo4ArcCHwWe7c77wvL9YK85J+fUx5J+OWdhPHcAC1n+oqBLgEer6vFR/z/30n9JasTFfMpFkppioUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RG/D/+l6zvGl9WCwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.hist(initxs.flatten(), alpha=0.5)\n",
"plt.hist(finalxs.flatten(), alpha=0.5)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "9cbca9b6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([ 3., 12., 12., 26., 47., 48., 26., 16., 6., 4.]),\n",
" array([-2.6243906 , -2.10165634, -1.57892207, -1.05618781, -0.53345355,\n",
" -0.01071928, 0.51201498, 1.03474924, 1.55748351, 2.08021777,\n",
" 2.60295203]),\n",
" <BarContainer object of 10 artists>)"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAANBUlEQVR4nO3db4hl9X3H8fcnamswaaM4yjZKpw+WEAmNgcEGLKWNmtoasvaBJUlbFiosgYYaaGmmCTTaEthQCIHSB11q6JbmTwUTlGybut1E0kD+OFqTaNdUCRtjXdyJqUYpbTF++2DO0mWd9Z6ZuXfufN33C4bzZ86957PDzoff/u45Z1NVSJL6edW8A0iSNscCl6SmLHBJasoCl6SmLHBJaurc7TzZxRdfXIuLi9t5Sklq7/777/9BVS2cvn9bC3xxcZGVlZXtPKUktZfke+vtdwpFkpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpra1jsxpUkWlw/N5bzH9t8wl/NKW+EIXJKassAlqSkLXJKassAlqalRH2ImOQY8B/wYeKGqlpJcBPw9sAgcA36zqv5zNjElSafbyAj8V6rqyqpaGraXgSNVtRs4MmxLkrbJVqZQ9gAHh/WDwI1bTiNJGm1sgRdwT5L7k+wb9l1aVccBhuUl670wyb4kK0lWVldXt55YkgSMv5Hn6qp6MsklwOEkj4w9QVUdAA4ALC0t1SYySpLWMWoEXlVPDssTwOeAq4CnkuwCGJYnZhVSkvRSEws8yQVJXntyHXg78BBwN7B3OGwvcNesQkqSXmrMFMqlwOeSnDz+U1X1hST3AXckuRl4HLhpdjG1neb1PBJJGzOxwKvqu8Cb19n/NHDNLEJJkibzTkxJasoCl6SmfB64XhGOnf+erb3BrWOOeXZr55CmzBG4JDVlgUtSUxa4JDXlHLhmbsvz05LW5QhckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpoaXeBJzknyr0k+P2xflORwkkeH5YWziylJOt1GRuC3AEdP2V4GjlTVbuDIsC1J2iajCjzJZcANwF+fsnsPcHBYPwjcONVkkqSXNXYE/nHgj4AXT9l3aVUdBxiWl0w3miTp5Uws8CTvAE5U1f2bOUGSfUlWkqysrq5u5i0kSesYMwK/GnhnkmPAZ4C3Jfk74KkkuwCG5Yn1XlxVB6pqqaqWFhYWphRbkjSxwKvqj6vqsqpaBN4FfLGqfhu4G9g7HLYXuGtmKSVJL7GV68D3A9cleRS4btiWJG2TczdycFXdC9w7rD8NXDP9SJKkMbwTU5KassAlqakNTaFIZ7PF5UMzed9j+2+Yyfvqlc8RuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMTCzzJ+Um+keSbSR5Octuw/6Ikh5M8OiwvnH1cSdJJY0bg/wO8rareDFwJXJ/krcAycKSqdgNHhm1J0jaZWOC15vlh87zhq4A9wMFh/0HgxlkElCStb9QceJJzkjwInAAOV9XXgUur6jjAsLzkDK/dl2Qlycrq6uqUYkuSRhV4Vf24qq4ELgOuSvKmsSeoqgNVtVRVSwsLC5uMKUk63YauQqmqZ4B7geuBp5LsAhiWJ6YdTpJ0ZmOuQllI8rph/dXAtcAjwN3A3uGwvcBdM8ooSVrHuSOO2QUcTHIOa4V/R1V9PslXgTuS3Aw8Dtw0w5ySpNNMLPCq+hbwlnX2Pw1cM4tQkqTJvBNTkpqywC
WpKQtckpqywCWpKQtckpqywCWpqTHXgWtOFpcPzTuCpB3MEbgkNWWBS1JTFrgkNeUc+Fnu2PnvmXeENmb2s7r11PVnZ3MOvSI5ApekpixwSWrKApekpixwSWrKApekpixwSWrKApekpixwSWrKApekpixwSWrKApekpixwSWrKApekpixwSWrKApekpixwSWrKApekpixwSWrKApekpixwSWpqYoEnuTzJl5IcTfJwkluG/RclOZzk0WF54ezjSpJOGjMCfwH4g6p6I/BW4PeSXAEsA0eqajdwZNiWJG2TiQVeVcer6oFh/TngKPB6YA9wcDjsIHDjjDJKktaxoTnwJIvAW4CvA5dW1XFYK3ngkjO8Zl+SlSQrq6urW4wrSTppdIEneQ1wJ/D+qvrR2NdV1YGqWqqqpYWFhc1klCStY1SBJzmPtfL+ZFV9dtj9VJJdw/d3ASdmE1GStJ4xV6EEuB04WlUfO+VbdwN7h/W9wF3TjydJOpNzRxxzNfA7wLeTPDjs+yCwH7gjyc3A48BNM0koSVrXxAKvqq8AOcO3r5luHEnSWN6JKUlNWeCS1JQFLklNWeCS1JQFLklNWeCS1JQFLklNWeCS1JQFLklNWeCS1NSYZ6FI2i63/vQ2nOPZ2Z9D28IRuCQ1ZYFLUlMWuCQ1ZYFLZ5nF5UMsLh+adwxNgQUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU35PPDN2obnNh87f+ankNSYI3BJasoCl6SmLHBJasoCl6SmJhZ4kk8kOZHkoVP2XZTkcJJHh+WFs40pSTrdmBH43wDXn7ZvGThSVbuBI8O2JGkbTSzwqvoy8MPTdu8BDg7rB4EbpxtLkjTJZufAL62q4wDD8pIzHZhkX5KVJCurq6ubPJ0k6XQz/xCzqg5U1VJVLS0sLMz6dJJ01thsgT+VZBfAsDwxvUiSpDE2W+B3A3uH9b3AXdOJI0kaa8xlhJ8Gvgq8IckTSW4G9gPXJXkUuG7YliRto4kPs6qqd5/hW9dMOYskaQO8E1OSmrLAJampNs8DX1w+NLdzH9t/w9zOLUln4ghckpqywCWpKQtckppqMwcuabr8XKk/R+CS1JQFLklNWeCS1JRz4NJZ5tj575n5ORb/+1MzP4ccgUtSWxa4JDVlgUtSUxa4JDVlgUtSUxa4JDVlgUtSU14HLmnqJl5rfusUTnLrs1N4k94cgUtSUxa4JDVlgUtSU6/IOfCpP+vh1um+nSRNgyNwSWrKApekpixwSWrKApekpixwSWrKApekpixwSWrqFXkduCStZ3H50NzOfWz/DVN/zy2NwJNcn+Q7SR5LsjytUJKkyTZd4EnOAf4S+DXgCuDdSa6YVjBJ0svbygj8KuCxqvpuVf0v8Blgz3RiSZIm2coc+OuB75+y/QTwC6cflGQfsG/YfD7Jd17mPS8GfrCFTGvn3OobbM5Uss+R+efL/Bt121R/02eePx/d0st/dr2dWynw9X569ZIdVQeAA6PeMFmpqqUtZJqbztnB/PNm/vnqmn8rUyhPAJefsn0Z8OTW4kiSxtpKgd8H7E7yc0l+AngXcPd0YkmSJtn0FEpVvZDkfcA/AecAn6iqh7eYZ9RUyw7VOTuYf97MP18t86fqJdPWkqQGvJVekpqywCWpqR1V4En+LMm3kjyY5J4kPzPvTBuR5M+TPDL8GT6X5HXzzrQRSW5K8nCSF5O0uaSq8yMdknwiyYkkD807y2YkuTzJl5IcHf7u3DLvTBuR5Pwk30jyzSH/bfPOtBE7ag48yU9V1Y+G9d8Hrqiq98451mhJ3g58cfiA96MAVfWBOccaLckbgReBvwL+sKpW5hxpouGRDv8OXMfapa33Ae+uqn+ba7CRkvwS8Dzwt1X1pnnn2agku4BdVfVAktcC9wM3Nvr5B7igqp5Pch7wFeCWqvranKONsqNG4CfLe3AB69wYtJNV1T1V9cKw+T
XWro1vo6qOVtXL3Sm7E7V+pENVfRn44bxzbFZVHa+qB4b154CjrN2l3UKteX7YPG/4atM7O6rAAZJ8JMn3gd8C/mTeebbgd4F/nHeIs8B6j3RoUyCvJEkWgbcAX59zlA1Jck6SB4ETwOGqapN/2ws8yT8neWidrz0AVfWhqroc+CTwvu3ON8mk/MMxHwJeYO3PsKOMyd/MqEc6aLaSvAa4E3j/af+S3vGq6sdVdSVr/2K+Kkmbqaxt/w8dqurakYd+CjgEfHiGcTZsUv4ke4F3ANfUTvqAYbCBn38XPtJhzoa54zuBT1bVZ+edZ7Oq6pkk9wLXAy0+VN5RUyhJdp+y+U7gkXll2Ywk1wMfAN5ZVf817zxnCR/pMEfDh4C3A0er6mPzzrNRSRZOXi2W5NXAtTTqnZ12FcqdwBtYuxLie8B7q+o/5ptqvCSPAT8JPD3s+lqzq2h+A/gLYAF4Bniwqn51rqFGSPLrwMf5/0c6fGS+icZL8mngl1l7nOlTwIer6va5htqAJL8I/AvwbdZ+bwE+WFX/ML9U4yX5eeAga393XgXcUVV/Ot9U4+2oApckjbejplAkSeNZ4JLUlAUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU39H3kbJQE1TQQvAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.hist(initvs.flatten())\n",
"plt.hist(finalvs.flatten())"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "f6ec1c54",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(-0.00336791, dtype=float64)"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"free_energy(works)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "22c4c431",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment