Skip to content

Instantly share code, notes, and snippets.

@Alescontrela
Created April 23, 2024 21:17
Show Gist options
  • Save Alescontrela/a97ae89929b8589f7a81e0c6cbb7a7a1 to your computer and use it in GitHub Desktop.
Save Alescontrela/a97ae89929b8589f7a81e0c6cbb7a7a1 to your computer and use it in GitHub Desktop.
MPO Notebook
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-11-17 07:22:39.615782: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2023-11-17 07:22:39.615847: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2023-11-17 07:22:39.615868: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2023-11-17 07:22:40.371174: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
]
}
],
"source": [
"import optax\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import os\n",
"import numpy as np\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '5'\n",
"\n",
"from tensorflow_probability.substrates import jax as tfp\n",
"tfd = tfp.distributions\n"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.9337807\n",
"3.9337807\n",
"[[0.11920292 0.11920292]\n",
" [0.8807971 0.8807971 ]]\n"
]
}
],
"source": [
"def logsumexp1(a):\n",
" return jnp.mean(jax.scipy.special.logsumexp(a, axis=0) - jnp.log(a.shape[0]))\n",
"\n",
"def logsumexp2(b):\n",
" return jnp.mean(jnp.log(jnp.mean(jnp.exp(b), axis=0)))\n",
"\n",
"a = jnp.array([[1., 4.], [3., 6.]])\n",
"\n",
"print(logsumexp1(a))\n",
"print(logsumexp2(a))\n",
"print(jax.nn.softmax(a, axis=0))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"ename": "TracerBoolConversionError",
"evalue": "Attempted boolean conversion of traced array with shape bool[]..\nThe error occurred while tracing the function loss_fn at /tmp/ipykernel_2045319/1889768006.py:7 for jit. This concrete value was not available in Python because it depends on the value of the argument params['mean'].\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTracerBoolConversionError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/home/ale/awake/embodied/agents/mpo/test_kl.ipynb Cell 2\u001b[0m line \u001b[0;36m8\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W1sdnNjb2RlLXJlbW90ZQ%3D%3D?line=82'>83</a>\u001b[0m opt_state \u001b[39m=\u001b[39m optimizer\u001b[39m.\u001b[39minit(params)\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W1sdnNjb2RlLXJlbW90ZQ%3D%3D?line=84'>85</a>\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(\u001b[39m1000\u001b[39m):\n\u001b[0;32m---> <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W1sdnNjb2RlLXJlbW90ZQ%3D%3D?line=85'>86</a>\u001b[0m (loss, metrics), grads \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39;49mvalue_and_grad(loss_fn, has_aux\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)(params, key)\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W1sdnNjb2RlLXJlbW90ZQ%3D%3D?line=86'>87</a>\u001b[0m updates, opt_state \u001b[39m=\u001b[39m optimizer\u001b[39m.\u001b[39mupdate(grads, opt_state)\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W1sdnNjb2RlLXJlbW90ZQ%3D%3D?line=87'>88</a>\u001b[0m params \u001b[39m=\u001b[39m optax\u001b[39m.\u001b[39mapply_updates(params, updates)\n",
" \u001b[0;31m[... skipping hidden 20 frame]\u001b[0m\n",
"\u001b[1;32m/home/ale/awake/embodied/agents/mpo/test_kl.ipynb Cell 2\u001b[0m line \u001b[0;36m1\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W1sdnNjb2RlLXJlbW90ZQ%3D%3D?line=10'>11</a>\u001b[0m temperature \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39mnn\u001b[39m.\u001b[39msoftplus(params[\u001b[39m'\u001b[39m\u001b[39mlog_temperature\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m+\u001b[39m \u001b[39m1e-8\u001b[39m)\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W1sdnNjb2RlLXJlbW90ZQ%3D%3D?line=11'>12</a>\u001b[0m penalty_temperature \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39mnn\u001b[39m.\u001b[39msoftplus(params[\u001b[39m'\u001b[39m\u001b[39mlog_penalty_temperature\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m+\u001b[39m \u001b[39m1e-8\u001b[39m)\n\u001b[0;32m---> <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W1sdnNjb2RlLXJlbW90ZQ%3D%3D?line=12'>13</a>\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mnot\u001b[39;00m jnp\u001b[39m.\u001b[39misnan(mean)\u001b[39m.\u001b[39many(), mean\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W1sdnNjb2RlLXJlbW90ZQ%3D%3D?line=13'>14</a>\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mnot\u001b[39;00m jnp\u001b[39m.\u001b[39misnan(temperature)\u001b[39m.\u001b[39many(), temperature\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W1sdnNjb2RlLXJlbW90ZQ%3D%3D?line=14'>15</a>\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mnot\u001b[39;00m jnp\u001b[39m.\u001b[39misnan(penalty_temperature)\u001b[39m.\u001b[39many(), penalty_temperature\n",
" \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n",
"File \u001b[0;32m~/miniconda3/envs/awake/lib/python3.10/site-packages/jax/_src/core.py:1443\u001b[0m, in \u001b[0;36mconcretization_function_error.<locals>.error\u001b[0;34m(self, arg)\u001b[0m\n\u001b[1;32m 1442\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39merror\u001b[39m(\u001b[39mself\u001b[39m, arg):\n\u001b[0;32m-> 1443\u001b[0m \u001b[39mraise\u001b[39;00m TracerBoolConversionError(arg)\n",
"\u001b[0;31mTracerBoolConversionError\u001b[0m: Attempted boolean conversion of traced array with shape bool[]..\nThe error occurred while tracing the function loss_fn at /tmp/ipykernel_2045319/1889768006.py:7 for jit. This concrete value was not available in Python because it depends on the value of the argument params['mean'].\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError"
]
}
],
"source": [
"optimizer = optax.adam(1e-4)\n",
"key = jax.random.PRNGKey(42)\n",
"\n",
"def network(params):\n",
" return tfd.Independent(tfd.Normal(params, jnp.ones_like(params)), 1)\n",
"\n",
"@jax.jit\n",
"def loss_fn(params, key):\n",
" metrics = {}\n",
" mean = params['mean']\n",
" temperature = jax.nn.softplus(params['log_temperature'] + 1e-8)\n",
" penalty_temperature = jax.nn.softplus(params['log_penalty_temperature'] + 1e-8)\n",
" assert not jnp.isnan(mean).any(), mean\n",
" assert not jnp.isnan(temperature).any(), temperature\n",
" assert not jnp.isnan(penalty_temperature).any(), penalty_temperature\n",
"\n",
"\n",
" dist = network(mean)\n",
" a_improvement = dist.sample(20, key)\n",
" q_improvement = jax.lax.stop_gradient(jnp.sum(jnp.exp(-10 * (a_improvement - 0.75)**2), -1))\n",
" assert jnp.isfinite(q_improvement).all(), q_improvement\n",
"\n",
" # print('a_improvement', a_improvement.shape)\n",
" # print('q_improvement', q_improvement.shape)\n",
"\n",
" def compute_weights_and_temperature_loss(q_values, epsilon, temperature):\n",
" tempered_q_values = jax.lax.stop_gradient(q_values) / temperature\n",
" assert not jnp.isnan(tempered_q_values).any(), tempered_q_values\n",
" q_logsumexp = jnp.mean(jnp.log(jnp.mean(jnp.exp(tempered_q_values), axis=0)))\n",
" assert not jnp.isnan(q_logsumexp).any(), tempered_q_values\n",
" loss_temperature = jnp.mean(temperature * epsilon + temperature * q_logsumexp)\n",
" normalized_weights = jax.lax.stop_gradient(jax.nn.softmax(tempered_q_values, axis=0))\n",
" return normalized_weights, loss_temperature\n",
"\n",
" def compute_nonparametric_kl_from_normalized_weights(normalized_weights):\n",
" num_action_samples = normalized_weights.shape[0] / 1.\n",
" integrand = jnp.log(num_action_samples * normalized_weights + 1e-8)\n",
" non_parametric_kl = jnp.sum(normalized_weights * integrand, axis=0)\n",
" return non_parametric_kl\n",
"\n",
" normalized_weights, loss_temperature = compute_weights_and_temperature_loss(q_improvement, 0.1, temperature)\n",
" metrics['loss_temperature'] = loss_temperature\n",
" metrics['non_parametric_kl'] = compute_nonparametric_kl_from_normalized_weights(normalized_weights)\n",
"\n",
" assert not jnp.isnan(normalized_weights).any(), q_improvement\n",
" assert not jnp.isnan(loss_temperature).any(), temperature\n",
"\n",
" # print('normalized_weights', normalized_weights.shape)\n",
" # print('loss_temperature', loss_temperature.shape)\n",
"\n",
" cost_out_of_bound = -jnp.linalg.norm(a_improvement - jnp.clip(a_improvement, -1.0, 1.0), axis=-1)\n",
" penalty_normalized_weights, loss_penalty_temperature = compute_weights_and_temperature_loss(cost_out_of_bound, 0.001, penalty_temperature)\n",
" metrics['loss_penalty_temperature'] = loss_penalty_temperature\n",
" metrics['penalty_non_parametric_kl'] = compute_nonparametric_kl_from_normalized_weights(penalty_normalized_weights)\n",
"\n",
" assert not jnp.isnan(cost_out_of_bound).any(), q_improvement\n",
" assert not jnp.isnan(penalty_normalized_weights).any(), temperature\n",
" assert not jnp.isnan(loss_penalty_temperature).any(), temperature\n",
"\n",
" # print('cost_out_of_bound', cost_out_of_bound.shape)\n",
" # print('penalty_normalized_weights', penalty_normalized_weights.shape)\n",
" # print('loss_penalty_temperature', loss_penalty_temperature.shape)\n",
"\n",
" loss_temperature += loss_penalty_temperature\n",
" normalized_weights += penalty_normalized_weights\n",
"\n",
" logpi = dist.log_prob(jax.lax.stop_gradient(a_improvement))\n",
" loss_dist = jnp.mean(-jnp.sum(normalized_weights * logpi, axis=0))\n",
" metrics['loss_dist'] = loss_dist\n",
" assert not jnp.isnan(logpi).any(), temperature\n",
" assert not jnp.isnan(loss_dist).any(), temperature\n",
" # print('logpi', logpi.shape)\n",
" # print('loss_dist', loss_dist.shape)\n",
"\n",
" return loss_temperature + loss_dist, metrics\n",
"\n",
"key = jax.random.PRNGKey(41)\n",
"\n",
"params = {\n",
" 'mean': jnp.array([0.0]),\n",
" 'log_temperature': jnp.array([10.]),\n",
" 'log_penalty_temperature': jnp.array([10.])}\n",
"opt_state = optimizer.init(params)\n",
"\n",
"for i in range(1000):\n",
" (loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, key)\n",
" updates, opt_state = optimizer.update(grads, opt_state)\n",
" params = optax.apply_updates(params, updates)\n",
" if i % 50 == 0:\n",
" print(f'Iteration {i}')\n",
" print(f'\\t Loss: {loss}')\n",
" print(f'\\t Params: {params}')\n",
" print(f'\\t Metrics: {metrics}')\n",
" print(f'\\t Updates: {updates}')\n",
" print(f'\\t Grad: {grads}')"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration 0\n",
"\t Loss: -12.357587814331055\n",
"\t Params: {'log_penalty_temperature': Array([9.999944], dtype=float32), 'log_temperature': Array([10.], dtype=float32), 'mean': Array([ 9.96949 , 10.024852], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.36426342, dtype=float32), 'loss_penalty_temperature': Array(-12.721851, dtype=float32), 'loss_temperature': Array(1.0000046, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00043857, dtype=float32)}\n",
"\t Updates: {'log_penalty_temperature': Array([-5.614034e-05], dtype=float32), 'log_temperature': Array([-0.], dtype=float32), 'mean': Array([-0.03050957, 0.02485199], dtype=float32)}\n",
"\t Grad: {'log_penalty_temperature': Array([0.0005614], dtype=float32), 'log_temperature': Array([0.], dtype=float32), 'mean': Array([ 0.30509564, -0.24851994], dtype=float32)}\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/home/ale/awake/embodied/agents/mpo/test_kl.ipynb Cell 3\u001b[0m line \u001b[0;36m8\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=86'>87</a>\u001b[0m slowdist \u001b[39m=\u001b[39m tfd\u001b[39m.\u001b[39mIndependent(tfd\u001b[39m.\u001b[39mNormal(params[\u001b[39m'\u001b[39m\u001b[39mmean\u001b[39m\u001b[39m'\u001b[39m], jax\u001b[39m.\u001b[39mlax\u001b[39m.\u001b[39mstop_gradient(std \u001b[39m*\u001b[39m jnp\u001b[39m.\u001b[39mones_like(params[\u001b[39m'\u001b[39m\u001b[39mmean\u001b[39m\u001b[39m'\u001b[39m]))), \u001b[39m1\u001b[39m)\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=87'>88</a>\u001b[0m _, key \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39mrandom\u001b[39m.\u001b[39msplit(key)\n\u001b[0;32m---> <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=88'>89</a>\u001b[0m (loss, metrics), grads \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39;49mvalue_and_grad(loss_fn, has_aux\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)(params, key, slowdist)\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=89'>90</a>\u001b[0m updates, opt_state \u001b[39m=\u001b[39m optimizer\u001b[39m.\u001b[39mupdate(grads, opt_state)\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=90'>91</a>\u001b[0m params \u001b[39m=\u001b[39m optax\u001b[39m.\u001b[39mapply_updates(params, updates)\n",
" \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n",
"File \u001b[0;32m~/miniconda3/envs/awake/lib/python3.10/site-packages/jax/_src/api.py:734\u001b[0m, in \u001b[0;36mvalue_and_grad.<locals>.value_and_grad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 732\u001b[0m ans, vjp_py \u001b[39m=\u001b[39m _vjp(f_partial, \u001b[39m*\u001b[39mdyn_args, reduce_axes\u001b[39m=\u001b[39mreduce_axes)\n\u001b[1;32m 733\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 734\u001b[0m ans, vjp_py, aux \u001b[39m=\u001b[39m _vjp(\n\u001b[1;32m 735\u001b[0m f_partial, \u001b[39m*\u001b[39;49mdyn_args, has_aux\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, reduce_axes\u001b[39m=\u001b[39;49mreduce_axes)\n\u001b[1;32m 736\u001b[0m _check_scalar(ans)\n\u001b[1;32m 737\u001b[0m tree_map(partial(_check_output_dtype_grad, holomorphic), ans)\n",
"File \u001b[0;32m~/miniconda3/envs/awake/lib/python3.10/site-packages/jax/_src/api.py:2243\u001b[0m, in \u001b[0;36m_vjp\u001b[0;34m(fun, has_aux, reduce_axes, *primals)\u001b[0m\n\u001b[1;32m 2241\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 2242\u001b[0m flat_fun, out_aux_trees \u001b[39m=\u001b[39m flatten_fun_nokwargs2(fun, in_tree)\n\u001b[0;32m-> 2243\u001b[0m out_primal, out_vjp, aux \u001b[39m=\u001b[39m ad\u001b[39m.\u001b[39;49mvjp(\n\u001b[1;32m 2244\u001b[0m flat_fun, primals_flat, has_aux\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, reduce_axes\u001b[39m=\u001b[39;49mreduce_axes)\n\u001b[1;32m 2245\u001b[0m out_tree, aux_tree \u001b[39m=\u001b[39m out_aux_trees()\n\u001b[1;32m 2246\u001b[0m out_primal_py \u001b[39m=\u001b[39m tree_unflatten(out_tree, out_primal)\n",
"File \u001b[0;32m~/miniconda3/envs/awake/lib/python3.10/site-packages/jax/_src/interpreters/ad.py:142\u001b[0m, in \u001b[0;36mvjp\u001b[0;34m(traceable, primals, has_aux, reduce_axes)\u001b[0m\n\u001b[1;32m 140\u001b[0m out_primals, pvals, jaxpr, consts \u001b[39m=\u001b[39m linearize(traceable, \u001b[39m*\u001b[39mprimals)\n\u001b[1;32m 141\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 142\u001b[0m out_primals, pvals, jaxpr, consts, aux \u001b[39m=\u001b[39m linearize(traceable, \u001b[39m*\u001b[39;49mprimals, has_aux\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n\u001b[1;32m 144\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39munbound_vjp\u001b[39m(pvals, jaxpr, consts, \u001b[39m*\u001b[39mcts):\n\u001b[1;32m 145\u001b[0m cts \u001b[39m=\u001b[39m \u001b[39mtuple\u001b[39m(ct \u001b[39mfor\u001b[39;00m ct, pval \u001b[39min\u001b[39;00m \u001b[39mzip\u001b[39m(cts, pvals) \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m pval\u001b[39m.\u001b[39mis_known())\n",
"File \u001b[0;32m~/miniconda3/envs/awake/lib/python3.10/site-packages/jax/_src/interpreters/ad.py:129\u001b[0m, in \u001b[0;36mlinearize\u001b[0;34m(traceable, *primals, **kwargs)\u001b[0m\n\u001b[1;32m 127\u001b[0m _, in_tree \u001b[39m=\u001b[39m tree_flatten(((primals, primals), {}))\n\u001b[1;32m 128\u001b[0m jvpfun_flat, out_tree \u001b[39m=\u001b[39m flatten_fun(jvpfun, in_tree)\n\u001b[0;32m--> 129\u001b[0m jaxpr, out_pvals, consts \u001b[39m=\u001b[39m pe\u001b[39m.\u001b[39;49mtrace_to_jaxpr_nounits(jvpfun_flat, in_pvals)\n\u001b[1;32m 130\u001b[0m out_primals_pvals, out_tangents_pvals \u001b[39m=\u001b[39m tree_unflatten(out_tree(), out_pvals)\n\u001b[1;32m 131\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mall\u001b[39m(out_primal_pval\u001b[39m.\u001b[39mis_known() \u001b[39mfor\u001b[39;00m out_primal_pval \u001b[39min\u001b[39;00m out_primals_pvals)\n",
"File \u001b[0;32m~/miniconda3/envs/awake/lib/python3.10/site-packages/jax/_src/profiler.py:340\u001b[0m, in \u001b[0;36mannotate_function.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 337\u001b[0m \u001b[39m@wraps\u001b[39m(func)\n\u001b[1;32m 338\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mwrapper\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 339\u001b[0m \u001b[39mwith\u001b[39;00m TraceAnnotation(name, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mdecorator_kwargs):\n\u001b[0;32m--> 340\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 341\u001b[0m \u001b[39mreturn\u001b[39;00m wrapper\n",
"File \u001b[0;32m~/miniconda3/envs/awake/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:774\u001b[0m, in \u001b[0;36mtrace_to_jaxpr_nounits\u001b[0;34m(fun, pvals, instantiate)\u001b[0m\n\u001b[1;32m 772\u001b[0m \u001b[39mwith\u001b[39;00m core\u001b[39m.\u001b[39mnew_main(JaxprTrace, name_stack\u001b[39m=\u001b[39mcurrent_name_stack) \u001b[39mas\u001b[39;00m main:\n\u001b[1;32m 773\u001b[0m fun \u001b[39m=\u001b[39m trace_to_subjaxpr_nounits(fun, main, instantiate)\n\u001b[0;32m--> 774\u001b[0m jaxpr, (out_pvals, consts, env) \u001b[39m=\u001b[39m fun\u001b[39m.\u001b[39;49mcall_wrapped(pvals)\n\u001b[1;32m 775\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mnot\u001b[39;00m env\n\u001b[1;32m 776\u001b[0m \u001b[39mdel\u001b[39;00m main, fun, env\n",
"File \u001b[0;32m~/miniconda3/envs/awake/lib/python3.10/site-packages/jax/_src/linear_util.py:191\u001b[0m, in \u001b[0;36mWrappedFun.call_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 188\u001b[0m gen \u001b[39m=\u001b[39m gen_static_args \u001b[39m=\u001b[39m out_store \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 190\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 191\u001b[0m ans \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mf(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49m\u001b[39mdict\u001b[39;49m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mparams, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs))\n\u001b[1;32m 192\u001b[0m \u001b[39mexcept\u001b[39;00m:\n\u001b[1;32m 193\u001b[0m \u001b[39m# Some transformations yield from inside context managers, so we have to\u001b[39;00m\n\u001b[1;32m 194\u001b[0m \u001b[39m# interrupt them before reraising the exception. Otherwise they will only\u001b[39;00m\n\u001b[1;32m 195\u001b[0m \u001b[39m# get garbage-collected at some later time, running their cleanup tasks\u001b[39;00m\n\u001b[1;32m 196\u001b[0m \u001b[39m# only after this exception is handled, which can corrupt the global\u001b[39;00m\n\u001b[1;32m 197\u001b[0m \u001b[39m# state.\u001b[39;00m\n\u001b[1;32m 198\u001b[0m \u001b[39mwhile\u001b[39;00m stack:\n",
"\u001b[1;32m/home/ale/awake/embodied/agents/mpo/test_kl.ipynb Cell 3\u001b[0m line \u001b[0;36m4\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=41'>42</a>\u001b[0m diff_out_of_bound \u001b[39m=\u001b[39m a_improvement \u001b[39m-\u001b[39m jnp\u001b[39m.\u001b[39mclip(a_improvement, \u001b[39m-\u001b[39m\u001b[39m1.0\u001b[39m, \u001b[39m1.0\u001b[39m)\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=42'>43</a>\u001b[0m cost_out_of_bound \u001b[39m=\u001b[39m \u001b[39m-\u001b[39mjnp\u001b[39m.\u001b[39mlinalg\u001b[39m.\u001b[39mnorm(diff_out_of_bound, axis\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[0;32m---> <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=43'>44</a>\u001b[0m penalty_normalized_weights, loss_penalty_temperature, penalty_non_parametric_kl \u001b[39m=\u001b[39m compute_weights_and_temperature_loss(cost_out_of_bound, \u001b[39m0.001\u001b[39;49m, penalty_temperature)\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=44'>45</a>\u001b[0m metrics[\u001b[39m'\u001b[39m\u001b[39mloss_penalty_temperature\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m loss_penalty_temperature\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=45'>46</a>\u001b[0m metrics[\u001b[39m'\u001b[39m\u001b[39mpenalty_non_parametric_kl\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m penalty_non_parametric_kl\n",
"\u001b[1;32m/home/ale/awake/embodied/agents/mpo/test_kl.ipynb Cell 3\u001b[0m line \u001b[0;36m2\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=25'>26</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mcompute_weights_and_temperature_loss\u001b[39m(q_values, epsilon, temperature):\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=26'>27</a>\u001b[0m tempered_q_values \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39mlax\u001b[39m.\u001b[39mstop_gradient(q_values) \u001b[39m/\u001b[39m temperature\n\u001b[0;32m---> <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=27'>28</a>\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mnot\u001b[39;00m jnp\u001b[39m.\u001b[39;49misnan(tempered_q_values)\u001b[39m.\u001b[39;49many(), tempered_q_values\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=28'>29</a>\u001b[0m q_logsumexp \u001b[39m=\u001b[39m jnp\u001b[39m.\u001b[39mmean(jnp\u001b[39m.\u001b[39mlog(jnp\u001b[39m.\u001b[39mmean(jnp\u001b[39m.\u001b[39mexp(tempered_q_values), axis\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m)))\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=29'>30</a>\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mnot\u001b[39;00m jnp\u001b[39m.\u001b[39misnan(q_logsumexp)\u001b[39m.\u001b[39many(), tempered_q_values\n",
"File \u001b[0;32m~/miniconda3/envs/awake/lib/python3.10/site-packages/jax/_src/numpy/reductions.py:296\u001b[0m, in \u001b[0;36many\u001b[0;34m(a, axis, out, keepdims, where)\u001b[0m\n\u001b[1;32m 293\u001b[0m \u001b[39m@_wraps\u001b[39m(np\u001b[39m.\u001b[39many, skip_params\u001b[39m=\u001b[39m[\u001b[39m'\u001b[39m\u001b[39mout\u001b[39m\u001b[39m'\u001b[39m])\n\u001b[1;32m 294\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39many\u001b[39m(a: ArrayLike, axis: Axis \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m, out: \u001b[39mNone\u001b[39;00m \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 295\u001b[0m keepdims: \u001b[39mbool\u001b[39m \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m, \u001b[39m*\u001b[39m, where: Optional[ArrayLike] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Array:\n\u001b[0;32m--> 296\u001b[0m \u001b[39mreturn\u001b[39;00m _reduce_any(a, axis\u001b[39m=\u001b[39;49m_ensure_optional_axes(axis), out\u001b[39m=\u001b[39;49mout,\n\u001b[1;32m 297\u001b[0m keepdims\u001b[39m=\u001b[39;49mkeepdims, where\u001b[39m=\u001b[39;49mwhere)\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"optimizer = optax.sgd(1e-1)\n",
"key = jax.random.PRNGKey(41)\n",
"target = -0.75\n",
"n_act_samples = 100\n",
"std = 0.3\n",
"slowdist_update_freq = 1e9\n",
"use_objective_constraint = False\n",
"\n",
"\n",
"def loss_fn(params, key, slowdist):\n",
" metrics = {}\n",
" mean = params['mean']\n",
" temperature = jax.nn.softplus(params['log_temperature']) + 1e-8\n",
" penalty_temperature = jax.nn.softplus(params['log_penalty_temperature']) + 1e-8\n",
" assert not jnp.isnan(mean).any(), mean\n",
" assert not jnp.isnan(temperature).any(), temperature\n",
" assert not jnp.isnan(penalty_temperature).any(), penalty_temperature\n",
"\n",
" dist = tfd.Independent(tfd.Normal(mean, jax.lax.stop_gradient(std * jnp.ones_like(mean))), reinterpreted_batch_ndims=1)\n",
" a_improvement = dist.sample(n_act_samples, seed=key)\n",
" assert a_improvement.shape == (n_act_samples,) + mean.shape, a_improvement.shape\n",
" q_improvement = jax.lax.stop_gradient(jnp.sum(jnp.exp(-30 * (a_improvement - target)**2), -1))\n",
"\n",
" assert jnp.isfinite(q_improvement).all(), q_improvement\n",
"\n",
" def compute_weights_and_temperature_loss(q_values, epsilon, temperature):\n",
" tempered_q_values = jax.lax.stop_gradient(q_values) / temperature\n",
" assert not jnp.isnan(tempered_q_values).any(), tempered_q_values\n",
" q_logsumexp = jnp.mean(jnp.log(jnp.mean(jnp.exp(tempered_q_values), axis=0)))\n",
" assert not jnp.isnan(q_logsumexp).any(), tempered_q_values\n",
" loss_temperature = jnp.mean(temperature * epsilon + temperature * q_logsumexp)\n",
" normalized_weights = jax.lax.stop_gradient(jax.nn.softmax(tempered_q_values, axis=0))\n",
" num_action_samples = normalized_weights.shape[0]\n",
" integrand = jnp.log(num_action_samples * normalized_weights + 1e-8)\n",
" non_parametric_kl = jnp.sum(normalized_weights * integrand, axis=0)\n",
" return normalized_weights, loss_temperature, non_parametric_kl\n",
"\n",
" normalized_weights, loss_temperature, non_parametric_kl = compute_weights_and_temperature_loss(q_improvement, 0.1, temperature)\n",
" metrics['loss_temperature'] = loss_temperature\n",
" metrics['non_parametric_kl'] = non_parametric_kl\n",
"\n",
" diff_out_of_bound = a_improvement - jnp.clip(a_improvement, -1.0, 1.0)\n",
" cost_out_of_bound = -jnp.linalg.norm(diff_out_of_bound, axis=-1)\n",
" penalty_normalized_weights, loss_penalty_temperature, penalty_non_parametric_kl = compute_weights_and_temperature_loss(cost_out_of_bound, 0.001, penalty_temperature)\n",
" metrics['loss_penalty_temperature'] = loss_penalty_temperature\n",
" metrics['penalty_non_parametric_kl'] = penalty_non_parametric_kl\n",
"\n",
" # print('SAMPLED ITEMS')\n",
" # print('actions: ', a_improvement)\n",
" # print('costs:', cost_out_of_bound)\n",
" # print('penalties: ', penalty_normalized_weights)\n",
"\n",
" assert not jnp.isnan(cost_out_of_bound).any(), q_improvement\n",
" assert not jnp.isnan(penalty_normalized_weights).any(), temperature\n",
" assert not jnp.isnan(loss_penalty_temperature).any(), temperature\n",
"\n",
" # print('cost_out_of_bound', cost_out_of_bound.shape)\n",
" # print('penalty_normalized_weights', penalty_normalized_weights.shape)\n",
" # print('loss_penalty_temperature', loss_penalty_temperature.shape)\n",
" if use_objective_constraint:\n",
" loss_temperature += loss_penalty_temperature\n",
" normalized_weights += penalty_normalized_weights\n",
" else:\n",
" loss_temperature = loss_penalty_temperature\n",
" normalized_weights = penalty_normalized_weights\n",
"\n",
" logpi = dist.log_prob(jax.lax.stop_gradient(a_improvement))\n",
" loss_dist = jnp.mean(-jnp.sum(normalized_weights * logpi, axis=0))\n",
" metrics['loss_dist'] = loss_dist\n",
" assert not jnp.isnan(logpi).any(), temperature\n",
" assert not jnp.isnan(loss_dist).any(), temperature\n",
" loss = loss_dist + loss_temperature\n",
"\n",
" return loss, metrics\n",
"\n",
"\n",
"params = {\n",
" 'mean': jnp.array([10.0, 10.0]),\n",
" 'log_temperature': jnp.array([10.]),\n",
" 'log_penalty_temperature': jnp.array([10.]),\n",
" 'log_alpha': jnp.array([10., 10.]),}\n",
"opt_state = optimizer.init(params)\n",
"slowdist = tfd.Independent(tfd.Normal(params['mean'], jax.lax.stop_gradient(std * jnp.ones_like(params['mean']))), 1)\n",
"\n",
"means = []\n",
"for i in range(3000):\n",
" if i % slowdist_update_freq == 0:\n",
" slowdist = tfd.Independent(tfd.Normal(params['mean'], jax.lax.stop_gradient(std * jnp.ones_like(params['mean']))), 1)\n",
" _, key = jax.random.split(key)\n",
" (loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, key, slowdist)\n",
" updates, opt_state = optimizer.update(grads, opt_state)\n",
" params = optax.apply_updates(params, updates)\n",
" means.append(params['mean'])\n",
" if i % 50 == 0:\n",
" print(f'Iteration {i}')\n",
" print(f'\\t Loss: {loss}')\n",
" print(f'\\t Params: {params}')\n",
" print(f'\\t Metrics: {metrics}')\n",
" print(f'\\t Updates: {updates}')\n",
" print(f'\\t Grad: {grads}')"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"a = np.array(means)\n",
"plt.scatter(a[:, 0], a[:, 1])\n",
"plt.scatter(a[0, 0], a[0, 1], color='g')\n",
"# plt.scatter(target, target, color='r')\n",
"plt.plot([-1, 1], [1, 1], color='r')\n",
"plt.plot([-1, 1], [-1, -1], color='r')\n",
"plt.plot([-1, -1], [-1, 1], color='r')\n",
"plt.plot([1, 1], [-1, 1], color='r')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"plt.plot(means)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration 0\n",
"\t Loss: 0.48589494824409485\n",
"\t Params: {'log_alpha': Array([9.99, 9.99], dtype=float32), 'log_penalty_temperature': Array([9.99], dtype=float32), 'log_temperature': Array([9.99], dtype=float32), 'mean': Array([1.99, 2.01], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.7199209, dtype=float32), 'loss_penalty_temperature': Array(-1.4340315, dtype=float32), 'loss_temperature': Array(1.0000046, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00042423, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00999992, -0.00999992], dtype=float32), 'log_penalty_temperature': Array([-0.00999976], dtype=float32), 'log_temperature': Array([-0.00999993], dtype=float32), 'mean': Array([-0.00999993, 0.00999993], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00999954, 0.00999954], dtype=float32), 'log_penalty_temperature': Array([0.00057578], dtype=float32), 'log_temperature': Array([0.09999543], dtype=float32), 'mean': Array([ 0.5403554, -0.5681403], dtype=float32)}\n",
"\t Slowdist: (Array([2., 2.], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 100\n",
"\t Loss: 0.38626179099082947\n",
"\t Params: {'log_alpha': Array([8.992878, 8.990478], dtype=float32), 'log_penalty_temperature': Array([9.024538], dtype=float32), 'log_temperature': Array([8.990006], dtype=float32), 'mean': Array([1.9993048, 2.0023365], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.7204943, dtype=float32), 'loss_penalty_temperature': Array(-1.4142807, dtype=float32), 'loss_temperature': Array(0.9000129, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00057536, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.01000603, -0.00991931], dtype=float32), 'log_penalty_temperature': Array([-0.00877801], dtype=float32), 'log_temperature': Array([-0.00999959], dtype=float32), 'mean': Array([-0.0025167 , 0.00116931], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00999877, 0.00999877], dtype=float32), 'log_penalty_temperature': Array([0.00042453], dtype=float32), 'log_temperature': Array([0.0999877], dtype=float32), 'mean': Array([ 0.7919506 , -0.19585393], dtype=float32)}\n",
"\t Slowdist: (Array([2.0018215, 2.0011673], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 200\n",
"\t Loss: 0.5832449197769165\n",
"\t Params: {'log_alpha': Array([7.9902987, 7.9908733], dtype=float32), 'log_penalty_temperature': Array([8.1854515], dtype=float32), 'log_temperature': Array([7.990103], dtype=float32), 'mean': Array([2.009327 , 2.0057862], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(1.0127838, dtype=float32), 'loss_penalty_temperature': Array(-1.3896015, dtype=float32), 'loss_temperature': Array(0.8000437, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00076604, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.01007875, -0.0099753 ], dtype=float32), 'log_penalty_temperature': Array([-0.0070219], dtype=float32), 'log_temperature': Array([-0.00999835], dtype=float32), 'mean': Array([-0.0023564 , 0.00095438], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00999664, 0.00999665], dtype=float32), 'log_penalty_temperature': Array([0.00023387], dtype=float32), 'log_temperature': Array([0.09996644], dtype=float32), 'mean': Array([0.96598744, 0.9090329 ], dtype=float32)}\n",
"\t Slowdist: (Array([2.0116832, 2.0048318], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 300\n",
"\t Loss: 0.05826491117477417\n",
"\t Params: {'log_alpha': Array([6.987654 , 6.9963055], dtype=float32), 'log_penalty_temperature': Array([7.508599], dtype=float32), 'log_temperature': Array([6.9904213], dtype=float32), 'mean': Array([2.0133703, 2.021823 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.6574119, dtype=float32), 'loss_penalty_temperature': Array(-1.4393355, dtype=float32), 'loss_temperature': Array(0.7001327, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00084413, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00999312, -0.009806 ], dtype=float32), 'log_penalty_temperature': Array([-0.00515563], dtype=float32), 'log_temperature': Array([-0.00999465], dtype=float32), 'mean': Array([ 0.00258124, -0.00360387], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00999087, 0.00999094], dtype=float32), 'log_penalty_temperature': Array([0.00015579], dtype=float32), 'log_temperature': Array([0.09990893], dtype=float32), 'mean': Array([-0.58560795, 1.0779083 ], dtype=float32)}\n",
"\t Slowdist: (Array([2.010789 , 2.0254269], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 400\n",
"\t Loss: 0.30656659603118896\n",
"\t Params: {'log_alpha': Array([5.989863, 5.999579], dtype=float32), 'log_penalty_temperature': Array([7.0235324], dtype=float32), 'log_temperature': Array([5.991402], dtype=float32), 'mean': Array([2.0024784, 2.01043 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(1.0522367, dtype=float32), 'loss_penalty_temperature': Array(-1.4661998, dtype=float32), 'loss_temperature': Array(0.60038584, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00114513, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00992742, -0.01008575], dtype=float32), 'log_penalty_temperature': Array([-0.00274716], dtype=float32), 'log_temperature': Array([-0.00998401], dtype=float32), 'mean': Array([ 0.00230019, -0.00075393], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00997527, 0.00997551], dtype=float32), 'log_penalty_temperature': Array([-0.00014494], dtype=float32), 'log_temperature': Array([0.09975307], dtype=float32), 'mean': Array([-0.43430525, -0.2679104 ], dtype=float32)}\n",
"\t Slowdist: (Array([2.000178, 2.011184], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 500\n",
"\t Loss: -0.056698769330978394\n",
"\t Params: {'log_alpha': Array([5.024377 , 5.0177097], dtype=float32), 'log_penalty_temperature': Array([6.7465873], dtype=float32), 'log_temperature': Array([4.9942617], dtype=float32), 'mean': Array([2.0223832, 2.0226328], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.7662944, dtype=float32), 'loss_penalty_temperature': Array(-1.424832, dtype=float32), 'loss_temperature': Array(0.5010903, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00099728, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.01003663, -0.00970492], dtype=float32), 'log_penalty_temperature': Array([-0.00167181], dtype=float32), 'log_temperature': Array([-0.00995409], dtype=float32), 'mean': Array([-0.00199603, 0.00121003], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00993532, 0.00993487], dtype=float32), 'log_penalty_temperature': Array([2.6026719e-06], dtype=float32), 'log_temperature': Array([0.09933352], dtype=float32), 'mean': Array([ 1.3728685 , -0.03071524], dtype=float32)}\n",
"\t Slowdist: (Array([2.0243793, 2.0214229], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 600\n",
"\t Loss: -0.37505072355270386\n",
"\t Params: {'log_alpha': Array([4.0242906, 4.0203586], dtype=float32), 'log_penalty_temperature': Array([6.687692], dtype=float32), 'log_temperature': Array([4.0023184], dtype=float32), 'mean': Array([2.0034206, 2.006138 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.5969576, dtype=float32), 'loss_penalty_temperature': Array(-1.4560175, dtype=float32), 'loss_temperature': Array(0.40301228, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00088668, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00988557, -0.00999509], dtype=float32), 'log_penalty_temperature': Array([-0.00236257], dtype=float32), 'log_temperature': Array([-0.00987229], dtype=float32), 'mean': Array([-0.00310783, 0.00024052], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00982608, 0.00982542], dtype=float32), 'log_penalty_temperature': Array([0.00011326], dtype=float32), 'log_temperature': Array([0.09822278], dtype=float32), 'mean': Array([ 0.03735625, -0.45226246], dtype=float32)}\n",
"\t Slowdist: (Array([2.0065284, 2.0058975], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 700\n",
"\t Loss: -0.4723517894744873\n",
"\t Params: {'log_alpha': Array([3.0500345, 3.0384471], dtype=float32), 'log_penalty_temperature': Array([6.555307], dtype=float32), 'log_temperature': Array([3.024243], dtype=float32), 'mean': Array([1.9649576, 1.9900335], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.53500116, dtype=float32), 'loss_penalty_temperature': Array(-1.3774451, dtype=float32), 'loss_temperature': Array(0.30809072, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00083905, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00975942, -0.00971999], dtype=float32), 'log_penalty_temperature': Array([-0.00033438], dtype=float32), 'log_temperature': Array([-0.00965884], dtype=float32), 'mean': Array([-0.00261138, 0.00052281], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00955203, 0.00954703], dtype=float32), 'log_penalty_temperature': Array([0.00016078], dtype=float32), 'log_temperature': Array([0.09540825], dtype=float32), 'mean': Array([0.55228657, 0.34433275], dtype=float32)}\n",
"\t Slowdist: (Array([1.967569 , 1.9895107], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 800\n",
"\t Loss: -0.37934017181396484\n",
"\t Params: {'log_alpha': Array([2.1406171, 2.1088328], dtype=float32), 'log_penalty_temperature': Array([6.643746], dtype=float32), 'log_temperature': Array([2.0805252], dtype=float32), 'mean': Array([1.9487295, 1.975047 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.7039587, dtype=float32), 'loss_penalty_temperature': Array(-1.3488435, dtype=float32), 'loss_temperature': Array(0.22063307, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00098052, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00868134, -0.0092011 ], dtype=float32), 'log_penalty_temperature': Array([0.00221696], dtype=float32), 'log_temperature': Array([-0.00915507], dtype=float32), 'mean': Array([ 0.00253179, -0.00010385], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00895603, 0.00892644], dtype=float32), 'log_penalty_temperature': Array([1.9493105e-05], dtype=float32), 'log_temperature': Array([0.08898961], dtype=float32), 'mean': Array([0.60005254, 0.42159754], dtype=float32)}\n",
"\t Slowdist: (Array([1.9461977, 1.9751508], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 900\n",
"\t Loss: -0.47689417004585266\n",
"\t Params: {'log_alpha': Array([1.2941847, 1.2398111], dtype=float32), 'log_penalty_temperature': Array([6.6119533], dtype=float32), 'log_temperature': Array([1.209866], dtype=float32), 'mean': Array([1.9591954, 1.969928 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.672686, dtype=float32), 'loss_penalty_temperature': Array(-1.327716, dtype=float32), 'loss_temperature': Array(0.14771825, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00099787, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00623151, -0.00824273], dtype=float32), 'log_penalty_temperature': Array([0.00071753], dtype=float32), 'log_temperature': Array([-0.00818303], dtype=float32), 'mean': Array([-0.00096734, -0.00164841], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00785905, 0.00776963], dtype=float32), 'log_penalty_temperature': Array([2.1111462e-06], dtype=float32), 'log_temperature': Array([0.077172], dtype=float32), 'mean': Array([1.0913274 , 0.76623327], dtype=float32)}\n",
"\t Slowdist: (Array([1.9601628, 1.9715765], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 1000\n",
"\t Loss: -0.43849456310272217\n",
"\t Params: {'log_alpha': Array([0.5829749 , 0.55818796], dtype=float32), 'log_penalty_temperature': Array([6.5943894], dtype=float32), 'log_temperature': Array([0.45890638], dtype=float32), 'mean': Array([1.8935105, 1.9568812], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.7864419, dtype=float32), 'loss_penalty_temperature': Array(-1.3406776, dtype=float32), 'loss_temperature': Array(0.09528756, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00095006, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00684806, -0.00586833], dtype=float32), 'log_penalty_temperature': Array([-0.00050934], dtype=float32), 'log_temperature': Array([-0.00681043], dtype=float32), 'mean': Array([-3.3961865e-04, 1.9497038e-05], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00643325, 0.00637391], dtype=float32), 'log_penalty_temperature': Array([4.9923146e-05], dtype=float32), 'log_temperature': Array([0.06143695], dtype=float32), 'mean': Array([ 0.02882566, -0.14189371], dtype=float32)}\n",
"\t Slowdist: (Array([1.8938501, 1.9568616], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 1100\n",
"\t Loss: -0.7183690071105957\n",
"\t Params: {'log_alpha': Array([ 0.07190983, -0.06525106], dtype=float32), 'log_penalty_temperature': Array([6.5413065], dtype=float32), 'log_temperature': Array([-0.15040292], dtype=float32), 'mean': Array([1.8542974, 1.9351246], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.47202608, dtype=float32), 'loss_penalty_temperature': Array(-1.2666891, dtype=float32), 'loss_temperature': Array(0.06232807, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00064767, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00587423, -0.00565061], dtype=float32), 'log_penalty_temperature': Array([-0.00246504], dtype=float32), 'log_temperature': Array([-0.00541953], dtype=float32), 'mean': Array([ 0.00384338, -0.00052311], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00519436, 0.00485104], dtype=float32), 'log_penalty_temperature': Array([0.0003519], dtype=float32), 'log_temperature': Array([0.04638175], dtype=float32), 'mean': Array([0.16494276, 0.64638025], dtype=float32)}\n",
"\t Slowdist: (Array([1.850454 , 1.9356477], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 1200\n",
"\t Loss: -0.33883675932884216\n",
"\t Params: {'log_alpha': Array([-0.30658388, -0.4717605 ], dtype=float32), 'log_penalty_temperature': Array([6.5475793], dtype=float32), 'log_temperature': Array([-0.6330766], dtype=float32), 'mean': Array([1.7614677, 1.9047223], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.7931469, dtype=float32), 'loss_penalty_temperature': Array(-1.185104, dtype=float32), 'loss_temperature': Array(0.04273852, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00108359, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00244294, -0.00195345], dtype=float32), 'log_penalty_temperature': Array([0.00066953], dtype=float32), 'log_temperature': Array([-0.00429875], dtype=float32), 'mean': Array([ 0.00024119, -0.00117829], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00424546, 0.00384662], dtype=float32), 'log_penalty_temperature': Array([-8.349264e-05], dtype=float32), 'log_temperature': Array([0.03477877], dtype=float32), 'mean': Array([0.18388186, 0.5941939 ], dtype=float32)}\n",
"\t Slowdist: (Array([1.7612265, 1.9059006], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 1300\n",
"\t Loss: -0.07564980536699295\n",
"\t Params: {'log_alpha': Array([-0.71924096, -0.81395394], dtype=float32), 'log_penalty_temperature': Array([6.591468], dtype=float32), 'log_temperature': Array([-1.0193971], dtype=float32), 'mean': Array([1.6711799, 1.8684354], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(1.0575967, dtype=float32), 'loss_penalty_temperature': Array(-1.1717963, dtype=float32), 'loss_temperature': Array(0.03090061, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00096734, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00243512, -0.00158517], dtype=float32), 'log_penalty_temperature': Array([1.16039855e-05], dtype=float32), 'log_temperature': Array([-0.00348137], dtype=float32), 'mean': Array([-0.00124618, -0.00151607], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00328097, 0.00307386], dtype=float32), 'log_penalty_temperature': Array([3.264638e-05], dtype=float32), 'log_temperature': Array([0.02658238], dtype=float32), 'mean': Array([-0.88345385, -0.17027438], dtype=float32)}\n",
"\t Slowdist: (Array([1.6724261, 1.8699515], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 1400\n",
"\t Loss: -0.08950022608041763\n",
"\t Params: {'log_alpha': Array([-1.0005636, -1.031226 ], dtype=float32), 'log_penalty_temperature': Array([6.509834], dtype=float32), 'log_temperature': Array([-1.3363887], dtype=float32), 'mean': Array([1.6164894, 1.7372513], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.88053566, dtype=float32), 'loss_penalty_temperature': Array(-0.99962384, dtype=float32), 'loss_temperature': Array(0.02339294, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00096601, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00263337, -0.00275709], dtype=float32), 'log_penalty_temperature': Array([0.00052862], dtype=float32), 'log_temperature': Array([-0.00289643], dtype=float32), 'mean': Array([-0.00088033, 0.00037097], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00269349, 0.00263381], dtype=float32), 'log_penalty_temperature': Array([3.3951907e-05], dtype=float32), 'log_temperature': Array([0.02085823], dtype=float32), 'mean': Array([0.01975106, 0.16146824], dtype=float32)}\n",
"\t Slowdist: (Array([1.6173698, 1.7368803], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 1500\n",
"\t Loss: 0.10572590678930283\n",
"\t Params: {'log_alpha': Array([-1.1550199, -1.2760484], dtype=float32), 'log_penalty_temperature': Array([6.415121], dtype=float32), 'log_temperature': Array([-1.6034839], dtype=float32), 'mean': Array([1.4417028, 1.6181877], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.87069184, dtype=float32), 'loss_penalty_temperature': Array(-0.78855175, dtype=float32), 'loss_temperature': Array(0.01837307, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00109604, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.0031911 , -0.00204667], dtype=float32), 'log_penalty_temperature': Array([0.0003904], dtype=float32), 'log_temperature': Array([-0.00247143], dtype=float32), 'mean': Array([-0.00176061, 0.00140082], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00240155, 0.00218573], dtype=float32), 'log_penalty_temperature': Array([-9.5972224e-05], dtype=float32), 'log_temperature': Array([0.01678402], dtype=float32), 'mean': Array([0.7196689 , 0.06541317], dtype=float32)}\n",
"\t Slowdist: (Array([1.4434634, 1.6167868], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 1600\n",
"\t Loss: 0.3289426267147064\n",
"\t Params: {'log_alpha': Array([-1.3779562, -1.5385824], dtype=float32), 'log_penalty_temperature': Array([6.3686275], dtype=float32), 'log_temperature': Array([-1.8338917], dtype=float32), 'mean': Array([1.2830744, 1.5767502], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.9890636, dtype=float32), 'loss_penalty_temperature': Array(-0.6791763, dtype=float32), 'loss_temperature': Array(0.01485365, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00090256, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00278836, -0.00169048], dtype=float32), 'log_penalty_temperature': Array([-0.00243795], dtype=float32), 'log_temperature': Array([-0.00215465], dtype=float32), 'mean': Array([-0.00366673, 0.00028807], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00201786, 0.00176988], dtype=float32), 'log_penalty_temperature': Array([9.730752e-05], dtype=float32), 'log_temperature': Array([0.01380315], dtype=float32), 'mean': Array([0.6677026 , 0.62343144], dtype=float32)}\n",
"\t Slowdist: (Array([1.2867411, 1.576462 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 1700\n",
"\t Loss: 0.26487991213798523\n",
"\t Params: {'log_alpha': Array([-1.6104323, -1.6123283], dtype=float32), 'log_penalty_temperature': Array([6.2255354], dtype=float32), 'log_temperature': Array([-2.0365977], dtype=float32), 'mean': Array([1.2099762, 1.3736261], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.80858874, dtype=float32), 'loss_penalty_temperature': Array(-0.55962497, dtype=float32), 'loss_temperature': Array(0.0122856, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00099774, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00012469, 0.00579932], dtype=float32), 'log_penalty_temperature': Array([-0.0016098], dtype=float32), 'log_temperature': Array([-0.00191217], dtype=float32), 'mean': Array([ 0.00558323, -0.00102937], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00166546, 0.00165463], dtype=float32), 'log_penalty_temperature': Array([2.2808442e-06], dtype=float32), 'log_temperature': Array([0.0115609], dtype=float32), 'mean': Array([-0.9892728, -0.9986051], dtype=float32)}\n",
"\t Slowdist: (Array([1.2043929, 1.3746555], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 1800\n",
"\t Loss: 0.71905118227005\n",
"\t Params: {'log_alpha': Array([-1.813717 , -1.7590309], dtype=float32), 'log_penalty_temperature': Array([5.8305917], dtype=float32), 'log_temperature': Array([-2.2178447], dtype=float32), 'mean': Array([1.173601 , 1.2468069], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(1.1764243, dtype=float32), 'loss_penalty_temperature': Array(-0.47082826, dtype=float32), 'loss_temperature': Array(0.01034867, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00119055, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00244736, -0.00255123], dtype=float32), 'log_penalty_temperature': Array([-0.00272762], dtype=float32), 'log_temperature': Array([-0.00172199], dtype=float32), 'mean': Array([-0.00032456, 0.00107128], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00140485, 0.00147232], dtype=float32), 'log_penalty_temperature': Array([-0.00018994], dtype=float32), 'log_temperature': Array([0.0098312], dtype=float32), 'mean': Array([-0.06363957, -1.4840741 ], dtype=float32)}\n",
"\t Slowdist: (Array([1.1739256, 1.2457355], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 1900\n",
"\t Loss: 0.2700989544391632\n",
"\t Params: {'log_alpha': Array([-1.9843779, -1.984129 ], dtype=float32), 'log_penalty_temperature': Array([5.828775], dtype=float32), 'log_temperature': Array([-2.3820794], dtype=float32), 'mean': Array([1.2608198, 1.2127216], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.7113207, dtype=float32), 'loss_penalty_temperature': Array(-0.45264933, dtype=float32), 'loss_temperature': Array(0.00884718, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00095067, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00210317, -0.00122771], dtype=float32), 'log_penalty_temperature': Array([-0.00220548], dtype=float32), 'log_temperature': Array([-0.00156959], dtype=float32), 'mean': Array([0.0010757 , 0.00107636], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00121077, 0.0012101 ], dtype=float32), 'log_penalty_temperature': Array([4.9154223e-05], dtype=float32), 'log_temperature': Array([0.00846711], dtype=float32), 'mean': Array([-0.3552358 , -0.26324248], dtype=float32)}\n",
"\t Slowdist: (Array([1.259744 , 1.2116452], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 2000\n",
"\t Loss: 0.5639240741729736\n",
"\t Params: {'log_alpha': Array([-2.1136417, -2.142192 ], dtype=float32), 'log_penalty_temperature': Array([5.539783], dtype=float32), 'log_temperature': Array([-2.5325553], dtype=float32), 'mean': Array([1.0863805, 1.1239536], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.83131915, dtype=float32), 'loss_penalty_temperature': Array(-0.27730554, dtype=float32), 'loss_temperature': Array(0.00765634, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00088018, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00166906, -0.00181505], dtype=float32), 'log_penalty_temperature': Array([-0.00603382], dtype=float32), 'log_temperature': Array([-0.00144518], dtype=float32), 'mean': Array([-0.00394974, 0.00316915], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00107939, 0.00105234], dtype=float32), 'log_penalty_temperature': Array([0.00011934], dtype=float32), 'log_temperature': Array([0.00737058], dtype=float32), 'mean': Array([0.75778717, 0.56609565], dtype=float32)}\n",
"\t Slowdist: (Array([1.0903302, 1.1207844], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 2100\n",
"\t Loss: 0.386263906955719\n",
"\t Params: {'log_alpha': Array([-2.2742171, -2.2462347], dtype=float32), 'log_penalty_temperature': Array([4.995144], dtype=float32), 'log_temperature': Array([-2.6717086], dtype=float32), 'mean': Array([0.99336016, 1.0540836 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.59873176, dtype=float32), 'loss_penalty_temperature': Array(-0.22115001, dtype=float32), 'loss_temperature': Array(0.00669358, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00090916, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00192246, -0.00198969], dtype=float32), 'log_penalty_temperature': Array([-0.00678547], dtype=float32), 'log_temperature': Array([-0.00134198], dtype=float32), 'mean': Array([-0.00134176, 0.00108523], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00093444, 0.00095847], dtype=float32), 'log_penalty_temperature': Array([9.023982e-05], dtype=float32), 'log_temperature': Array([0.00647448], dtype=float32), 'mean': Array([ 0.48044458, -1.0177667 ], dtype=float32)}\n",
"\t Slowdist: (Array([0.9947019, 1.0529983], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 2200\n",
"\t Loss: 0.1849360316991806\n",
"\t Params: {'log_alpha': Array([-2.4002688, -2.4197261], dtype=float32), 'log_penalty_temperature': Array([4.5854897], dtype=float32), 'log_temperature': Array([-2.8014035], dtype=float32), 'mean': Array([0.85853904, 1.0204238 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.32170644, dtype=float32), 'loss_penalty_temperature': Array(-0.14439498, dtype=float32), 'loss_temperature': Array(0.00590243, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00076417, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00060121, -0.00170178], dtype=float32), 'log_penalty_temperature': Array([-0.00588559], dtype=float32), 'log_temperature': Array([-0.00125519], dtype=float32), 'mean': Array([-0.0001468, -0.0040167], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00083198, 0.00081809], dtype=float32), 'log_penalty_temperature': Array([0.00023342], dtype=float32), 'log_temperature': Array([0.00573162], dtype=float32), 'mean': Array([-0.55491453, 0.0286716 ], dtype=float32)}\n",
"\t Slowdist: (Array([0.85868585, 1.0244405 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 2300\n",
"\t Loss: 0.4473106265068054\n",
"\t Params: {'log_alpha': Array([-2.540007, -2.483013], dtype=float32), 'log_penalty_temperature': Array([3.792916], dtype=float32), 'log_temperature': Array([-2.9230971], dtype=float32), 'mean': Array([0.8336165 , 0.88097906], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.5349268, dtype=float32), 'loss_penalty_temperature': Array(-0.09442136, dtype=float32), 'loss_temperature': Array(0.00524316, dtype=float32), 'non_parametric_kl': Array(0., dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00076899, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([ 0.0002761 , -0.00165652], dtype=float32), 'log_penalty_temperature': Array([-0.00691665], dtype=float32), 'log_temperature': Array([-0.00118132], dtype=float32), 'mean': Array([0.00037895, 0.0003383 ], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00073082, 0.00077176], dtype=float32), 'log_penalty_temperature': Array([0.00022585], dtype=float32), 'log_temperature': Array([0.00510808], dtype=float32), 'mean': Array([ 0.5534362, -0.8750114], dtype=float32)}\n",
"\t Slowdist: (Array([0.8332375 , 0.88064075], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 2400\n",
"\t Loss: 0.807602047920227\n",
"\t Params: {'log_alpha': Array([-2.6507826, -2.0357647], dtype=float32), 'log_penalty_temperature': Array([2.1145923], dtype=float32), 'log_temperature': Array([-2.9402852], dtype=float32), 'mean': Array([0.7194834 , 0.47410864], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.831452, dtype=float32), 'loss_penalty_temperature': Array(-0.03116937, dtype=float32), 'loss_temperature': Array(0.00540864, dtype=float32), 'non_parametric_kl': Array(0.00102571, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.0007985, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00033454, -0.00092696], dtype=float32), 'log_penalty_temperature': Array([-0.01689054], dtype=float32), 'log_temperature': Array([-0.0010934], dtype=float32), 'mean': Array([-0.001219 , 0.00106522], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00065961, 0.00115593], dtype=float32), 'log_penalty_temperature': Array([0.00018012], dtype=float32), 'log_temperature': Array([0.00497344], dtype=float32), 'mean': Array([-0.7813669 , 0.48359105], dtype=float32)}\n",
"\t Slowdist: (Array([0.72070235, 0.4730434 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 2500\n",
"\t Loss: 1.0776270627975464\n",
"\t Params: {'log_alpha': Array([-2.7716837, -0.1247569], dtype=float32), 'log_penalty_temperature': Array([1.5945727], dtype=float32), 'log_temperature': Array([-1.2691258], dtype=float32), 'mean': Array([ 0.66727597, -0.7238604 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.44543827, dtype=float32), 'loss_penalty_temperature': Array(-0.0440731, dtype=float32), 'loss_temperature': Array(0.6693616, dtype=float32), 'non_parametric_kl': Array(0.6907245, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00164663, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00145355, 0.00735338], dtype=float32), 'log_penalty_temperature': Array([0.0164979], dtype=float32), 'log_temperature': Array([0.02427382], dtype=float32), 'mean': Array([0.00040383, 0.00175113], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00058954, 0.0046702 ], dtype=float32), 'log_penalty_temperature': Array([-0.00053604], dtype=float32), 'log_temperature': Array([-0.12717003], dtype=float32), 'mean': Array([-0.8134592, 0.3306723], dtype=float32)}\n",
"\t Slowdist: (Array([ 0.66687214, -0.7256115 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 2600\n",
"\t Loss: 1.1849126815795898\n",
"\t Params: {'log_alpha': Array([-2.90543 , -0.13347557], dtype=float32), 'log_penalty_temperature': Array([2.2388756], dtype=float32), 'log_temperature': Array([-0.01948936], dtype=float32), 'mean': Array([ 0.6704495, -0.7375241], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.6538851, dtype=float32), 'loss_penalty_temperature': Array(-0.05011825, dtype=float32), 'loss_temperature': Array(0.57432216, dtype=float32), 'non_parametric_kl': Array(0.14861362, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00107078, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00138012, -0.00082077], dtype=float32), 'log_penalty_temperature': Array([-0.00057357], dtype=float32), 'log_temperature': Array([0.00458906], dtype=float32), 'mean': Array([-0.00421568, 0.00028569], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00051954, 0.00466885], dtype=float32), 'log_penalty_temperature': Array([-6.403944e-05], dtype=float32), 'log_temperature': Array([-0.02401413], dtype=float32), 'mean': Array([1.4939733, 0.9184401], dtype=float32)}\n",
"\t Slowdist: (Array([ 0.67466515, -0.7378098 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 2700\n",
"\t Loss: 1.0896447896957397\n",
"\t Params: {'log_alpha': Array([-2.9969327 , -0.21730313], dtype=float32), 'log_penalty_temperature': Array([1.9763631], dtype=float32), 'log_temperature': Array([0.19751509], dtype=float32), 'mean': Array([ 0.58097744, -0.72799224], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.4862138, dtype=float32), 'loss_penalty_temperature': Array(-0.03220581, dtype=float32), 'loss_temperature': Array(0.6292412, dtype=float32), 'non_parametric_kl': Array(0.11710127, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00095759, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00135964, -0.00084123], dtype=float32), 'log_penalty_temperature': Array([0.00265772], dtype=float32), 'log_temperature': Array([0.00117455], dtype=float32), 'mean': Array([0.0028756 , 0.00068726], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00047626, 0.00446095], dtype=float32), 'log_penalty_temperature': Array([3.7217927e-05], dtype=float32), 'log_temperature': Array([-0.00938742], dtype=float32), 'mean': Array([-0.42442703, 0.02910042], dtype=float32)}\n",
"\t Slowdist: (Array([ 0.5781018, -0.7286795], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 2800\n",
"\t Loss: 1.1050753593444824\n",
"\t Params: {'log_alpha': Array([-3.050156 , -0.3000981], dtype=float32), 'log_penalty_temperature': Array([2.2251701], dtype=float32), 'log_temperature': Array([0.24977182], dtype=float32), 'mean': Array([ 0.6046459 , -0.74899316], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.5832055, dtype=float32), 'loss_penalty_temperature': Array(-0.04146791, dtype=float32), 'loss_temperature': Array(0.557328, dtype=float32), 'non_parametric_kl': Array(0.09781045, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00091025, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00085679, -0.00084522], dtype=float32), 'log_penalty_temperature': Array([-0.0023049], dtype=float32), 'log_temperature': Array([-5.6389345e-05], dtype=float32), 'mean': Array([-0.00302355, -0.00158246], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00045248, 0.0042574 ], dtype=float32), 'log_penalty_temperature': Array([8.102108e-05], dtype=float32), 'log_temperature': Array([0.00123086], dtype=float32), 'mean': Array([1.0552318, 1.2339742], dtype=float32)}\n",
"\t Slowdist: (Array([ 0.6076695, -0.7474107], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 2900\n",
"\t Loss: 1.1336313486099243\n",
"\t Params: {'log_alpha': Array([-3.1656106, -0.3838625], dtype=float32), 'log_penalty_temperature': Array([1.9065849], dtype=float32), 'log_temperature': Array([0.26086798], dtype=float32), 'mean': Array([ 0.55069506, -0.7249758 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.49479327, dtype=float32), 'loss_penalty_temperature': Array(-0.02007033, dtype=float32), 'loss_temperature': Array(0.65329623, dtype=float32), 'non_parametric_kl': Array(0.09038533, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00061032, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00113085, -0.00079649], dtype=float32), 'log_penalty_temperature': Array([-0.00456058], dtype=float32), 'log_temperature': Array([0.00012203], dtype=float32), 'mean': Array([-2.933957e-03, 8.031288e-05], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00040525, 0.00405388], dtype=float32), 'log_penalty_temperature': Array([0.00033952], dtype=float32), 'log_temperature': Array([0.00543056], dtype=float32), 'mean': Array([ 1.3653088 , -0.13983467], dtype=float32)}\n",
"\t Slowdist: (Array([ 0.55362904, -0.7250561 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 3000\n",
"\t Loss: 1.3241337537765503\n",
"\t Params: {'log_alpha': Array([-3.2374494 , -0.46905267], dtype=float32), 'log_penalty_temperature': Array([1.7561283], dtype=float32), 'log_temperature': Array([0.27198327], dtype=float32), 'mean': Array([ 0.4894797 , -0.72456414], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.8379046, dtype=float32), 'loss_penalty_temperature': Array(-0.04172703, dtype=float32), 'loss_temperature': Array(0.5227087, dtype=float32), 'non_parametric_kl': Array(0.1108341, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00112506, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.0011386 , -0.00082004], dtype=float32), 'log_penalty_temperature': Array([-0.00068704], dtype=float32), 'log_temperature': Array([-0.00011344], dtype=float32), 'mean': Array([-0.00121486, -0.00156431], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00037822, 0.00385035], dtype=float32), 'log_penalty_temperature': Array([-0.00010664], dtype=float32), 'log_temperature': Array([-0.0061495], dtype=float32), 'mean': Array([0.31977478, 0.7613824 ], dtype=float32)}\n",
"\t Slowdist: (Array([ 0.49069455, -0.7229998 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 3100\n",
"\t Loss: 1.0049031972885132\n",
"\t Params: {'log_alpha': Array([-3.3247452, -0.5532442], dtype=float32), 'log_penalty_temperature': Array([1.9008057], dtype=float32), 'log_temperature': Array([0.25763676], dtype=float32), 'mean': Array([ 0.61811864, -0.71893704], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.43998265, dtype=float32), 'loss_penalty_temperature': Array(-0.01837213, dtype=float32), 'loss_temperature': Array(0.57839257, dtype=float32), 'non_parametric_kl': Array(0.09526702, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00046313, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00112125, -0.00085481], dtype=float32), 'log_penalty_temperature': Array([0.00314032], dtype=float32), 'log_temperature': Array([-4.903485e-05], dtype=float32), 'mean': Array([0.00244115, 0.00095713], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.0003477, 0.0036531], dtype=float32), 'log_penalty_temperature': Array([0.00046684], dtype=float32), 'log_temperature': Array([0.0026697], dtype=float32), 'mean': Array([ 0.47684246, -0.9267983 ], dtype=float32)}\n",
"\t Slowdist: (Array([ 0.6156775, -0.7198942], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 3200\n",
"\t Loss: 1.0701980590820312\n",
"\t Params: {'log_alpha': Array([-3.4265258, -0.6381286], dtype=float32), 'log_penalty_temperature': Array([1.8629454], dtype=float32), 'log_temperature': Array([0.2606489], dtype=float32), 'mean': Array([ 0.52374595, -0.7276362 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.5222799, dtype=float32), 'loss_penalty_temperature': Array(-0.03600911, dtype=float32), 'loss_temperature': Array(0.5793633, dtype=float32), 'non_parametric_kl': Array(0.10843924, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00137433, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([ 0.00030086, -0.0008235 ], dtype=float32), 'log_penalty_temperature': Array([-0.00200374], dtype=float32), 'log_temperature': Array([0.00020241], dtype=float32), 'mean': Array([ 0.00032308, -0.00079445], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00031468, 0.00345856], dtype=float32), 'log_penalty_temperature': Array([-0.00032414], dtype=float32), 'log_temperature': Array([-0.00476607], dtype=float32), 'mean': Array([-1.5116683 , 0.48110563], dtype=float32)}\n",
"\t Slowdist: (Array([ 0.5234229 , -0.72684175], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 3300\n",
"\t Loss: 1.2888963222503662\n",
"\t Params: {'log_alpha': Array([-3.4827292, -0.723467 ], dtype=float32), 'log_penalty_temperature': Array([1.841365], dtype=float32), 'log_temperature': Array([0.26143354], dtype=float32), 'mean': Array([ 0.58266836, -0.7255762 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.7518214, dtype=float32), 'loss_penalty_temperature': Array(-0.0330327, dtype=float32), 'loss_temperature': Array(0.56584734, dtype=float32), 'non_parametric_kl': Array(0.11011819, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00085835, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00112124, -0.0008573 ], dtype=float32), 'log_penalty_temperature': Array([-0.00145335], dtype=float32), 'log_temperature': Array([5.8486923e-05], dtype=float32), 'mean': Array([-0.00024395, -0.00027 ], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.0002984 , 0.00326819], dtype=float32), 'log_penalty_temperature': Array([0.00012231], dtype=float32), 'log_temperature': Array([-0.00571653], dtype=float32), 'mean': Array([-0.061869 , -0.0683924], dtype=float32)}\n",
"\t Slowdist: (Array([ 0.5829123, -0.7253062], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 3400\n",
"\t Loss: 1.181031346321106\n",
"\t Params: {'log_alpha': Array([-3.5839734, -0.8049648], dtype=float32), 'log_penalty_temperature': Array([1.7659545], dtype=float32), 'log_temperature': Array([0.2601185], dtype=float32), 'mean': Array([ 0.46356493, -0.7305318 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.5776532, dtype=float32), 'loss_penalty_temperature': Array(-0.02271618, dtype=float32), 'loss_temperature': Array(0.62212217, dtype=float32), 'non_parametric_kl': Array(0.10688792, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00092324, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00034355, -0.00084434], dtype=float32), 'log_penalty_temperature': Array([-0.00100898], dtype=float32), 'log_temperature': Array([0.00043917], dtype=float32), 'mean': Array([-0.00444987, 0.00079763], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00027024, 0.00309145], dtype=float32), 'log_penalty_temperature': Array([6.551496e-05], dtype=float32), 'log_temperature': Array([-0.00388864], dtype=float32), 'mean': Array([ 0.25994802, -0.5126397 ], dtype=float32)}\n",
"\t Slowdist: (Array([ 0.4680148 , -0.73132944], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 3500\n",
"\t Loss: 1.2093379497528076\n",
"\t Params: {'log_alpha': Array([-3.6777506 , -0.88798785], dtype=float32), 'log_penalty_temperature': Array([1.8160655], dtype=float32), 'log_temperature': Array([0.25180104], dtype=float32), 'mean': Array([ 0.44064492, -0.7212805 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.73009145, dtype=float32), 'loss_penalty_temperature': Array(-0.02735094, dtype=float32), 'loss_temperature': Array(0.5028988, dtype=float32), 'non_parametric_kl': Array(0.10871886, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00071977, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00059779, -0.00083397], dtype=float32), 'log_penalty_temperature': Array([-0.00296857], dtype=float32), 'log_temperature': Array([-8.2723756e-07], dtype=float32), 'mean': Array([-0.00175002, 0.00113147], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00024671, 0.00291698], dtype=float32), 'log_penalty_temperature': Array([0.00024108], dtype=float32), 'log_temperature': Array([-0.00490544], dtype=float32), 'mean': Array([-0.8181232 , -0.34866244], dtype=float32)}\n",
"\t Slowdist: (Array([ 0.44239494, -0.722412 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 3600\n",
"\t Loss: 1.2618625164031982\n",
"\t Params: {'log_alpha': Array([-3.7612443 , -0.97024506], dtype=float32), 'log_penalty_temperature': Array([1.5728055], dtype=float32), 'log_temperature': Array([0.2610512], dtype=float32), 'mean': Array([ 0.42576486, -0.69376504], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.7618042, dtype=float32), 'loss_penalty_temperature': Array(-0.03891052, dtype=float32), 'loss_temperature': Array(0.5355232, dtype=float32), 'non_parametric_kl': Array(0.09681669, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00134723, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00063297, -0.00074933], dtype=float32), 'log_penalty_temperature': Array([-0.00064625], dtype=float32), 'log_temperature': Array([5.994808e-05], dtype=float32), 'mean': Array([0.00238505, 0.00096242], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.0002274 , 0.00274981], dtype=float32), 'log_penalty_temperature': Array([-0.00028766], dtype=float32), 'log_temperature': Array([0.00179814], dtype=float32), 'mean': Array([-0.00766396, 1.1881571 ], dtype=float32)}\n",
"\t Slowdist: (Array([ 0.4233798, -0.6947275], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 3700\n",
"\t Loss: 1.1228140592575073\n",
"\t Params: {'log_alpha': Array([-3.851477 , -1.0469788], dtype=float32), 'log_penalty_temperature': Array([1.5722681], dtype=float32), 'log_temperature': Array([0.26480272], dtype=float32), 'mean': Array([ 0.4433347 , -0.70881426], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.52717745, dtype=float32), 'loss_penalty_temperature': Array(-0.02492876, dtype=float32), 'loss_temperature': Array(0.61734444, dtype=float32), 'non_parametric_kl': Array(0.10197209, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00100559, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00076744, -0.00081917], dtype=float32), 'log_penalty_temperature': Array([0.00473213], dtype=float32), 'log_temperature': Array([-0.00014354], dtype=float32), 'mean': Array([-0.00025343, -0.00068419], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00020822, 0.00259963], dtype=float32), 'log_penalty_temperature': Array([-4.641458e-06], dtype=float32), 'log_temperature': Array([-0.00111585], dtype=float32), 'mean': Array([0.1140548 , 0.01734571], dtype=float32)}\n",
"\t Slowdist: (Array([ 0.44358814, -0.70813006], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 3800\n",
"\t Loss: 0.9663702845573425\n",
"\t Params: {'log_alpha': Array([-3.9443755, -1.1271405], dtype=float32), 'log_penalty_temperature': Array([1.6506819], dtype=float32), 'log_temperature': Array([0.26421005], dtype=float32), 'mean': Array([ 0.46401244, -0.7176855 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.47267574, dtype=float32), 'loss_penalty_temperature': Array(-0.02276553, dtype=float32), 'loss_temperature': Array(0.5134599, dtype=float32), 'non_parametric_kl': Array(0.10405714, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00056215, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00085301, -0.00079027], dtype=float32), 'log_penalty_temperature': Array([-0.00293026], dtype=float32), 'log_temperature': Array([0.00012505], dtype=float32), 'mean': Array([ 0.00014212, -0.00105504], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00019011, 0.00244835], dtype=float32), 'log_penalty_temperature': Array([0.0003675], dtype=float32), 'log_temperature': Array([-0.00229491], dtype=float32), 'mean': Array([-0.5024124 , -0.08042979], dtype=float32)}\n",
"\t Slowdist: (Array([ 0.46387032, -0.71663046], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 3900\n",
"\t Loss: 1.122465968132019\n",
"\t Params: {'log_alpha': Array([-4.027606 , -1.1944473], dtype=float32), 'log_penalty_temperature': Array([1.9318701], dtype=float32), 'log_temperature': Array([0.26118097], dtype=float32), 'mean': Array([ 0.53312 , -0.7526177], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.57634616, dtype=float32), 'loss_penalty_temperature': Array(-0.02874156, dtype=float32), 'loss_temperature': Array(0.5720372, dtype=float32), 'non_parametric_kl': Array(0.09460305, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00073378, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00050246, -0.00078448], dtype=float32), 'log_penalty_temperature': Array([-0.00052318], dtype=float32), 'log_temperature': Array([0.00029714], dtype=float32), 'mean': Array([-0.0016112 , 0.00095536], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00017514, 0.00232604], dtype=float32), 'log_penalty_temperature': Array([0.00023249], dtype=float32), 'log_temperature': Array([0.00304851], dtype=float32), 'mean': Array([-0.03702972, -0.34394398], dtype=float32)}\n",
"\t Slowdist: (Array([ 0.53473115, -0.75357306], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 4000\n",
"\t Loss: 1.2240368127822876\n",
"\t Params: {'log_alpha': Array([-4.0519753, -1.2736094], dtype=float32), 'log_penalty_temperature': Array([1.6932304], dtype=float32), 'log_temperature': Array([0.27610442], dtype=float32), 'mean': Array([ 0.41210034, -0.7141145 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.7755399, dtype=float32), 'loss_penalty_temperature': Array(-0.04049371, dtype=float32), 'loss_temperature': Array(0.48634902, dtype=float32), 'non_parametric_kl': Array(0.09594899, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00150162, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00086878, -0.00081985], dtype=float32), 'log_penalty_temperature': Array([-0.00319381], dtype=float32), 'log_temperature': Array([0.0001152], dtype=float32), 'mean': Array([ 0.00258497, -0.00011732], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00017105, 0.0021878 ], dtype=float32), 'log_penalty_temperature': Array([-0.00042386], dtype=float32), 'log_temperature': Array([0.00230322], dtype=float32), 'mean': Array([-0.76216596, 0.6284508 ], dtype=float32)}\n",
"\t Slowdist: (Array([ 0.40951538, -0.7139972 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 4100\n",
"\t Loss: 1.1584421396255493\n",
"\t Params: {'log_alpha': Array([-4.1143885, -1.3473631], dtype=float32), 'log_penalty_temperature': Array([1.8222681], dtype=float32), 'log_temperature': Array([0.26632312], dtype=float32), 'mean': Array([ 0.32905233, -0.7514241 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.5818955, dtype=float32), 'loss_penalty_temperature': Array(-0.03542287, dtype=float32), 'loss_temperature': Array(0.6094952, dtype=float32), 'non_parametric_kl': Array(0.0970925, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00091967, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00080697, -0.00079886], dtype=float32), 'log_penalty_temperature': Array([-0.00224021], dtype=float32), 'log_temperature': Array([-0.00015753], dtype=float32), 'mean': Array([-0.00361085, 0.00064184], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00016086, 0.00206433], dtype=float32), 'log_penalty_temperature': Array([6.915838e-05], dtype=float32), 'log_temperature': Array([0.00164631], dtype=float32), 'mean': Array([-0.13039285, 0.53584063], dtype=float32)}\n",
"\t Slowdist: (Array([ 0.33266318, -0.7520659 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 4200\n",
"\t Loss: 1.2270063161849976\n",
"\t Params: {'log_alpha': Array([-4.185161 , -1.4265808], dtype=float32), 'log_penalty_temperature': Array([1.710859], dtype=float32), 'log_temperature': Array([0.27631286], dtype=float32), 'mean': Array([ 0.16631529, -0.7396459 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.6433314, dtype=float32), 'loss_penalty_temperature': Array(-0.0273851, dtype=float32), 'loss_temperature': Array(0.6087552, dtype=float32), 'non_parametric_kl': Array(0.10068294, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.0010332, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00069026, -0.00080192], dtype=float32), 'log_penalty_temperature': Array([0.00342814], dtype=float32), 'log_temperature': Array([0.00097706], dtype=float32), 'mean': Array([-0.00059863, -0.00081224], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00015002, 0.00193757], dtype=float32), 'log_penalty_temperature': Array([-2.8080945e-05], dtype=float32), 'log_temperature': Array([-0.0003881], dtype=float32), 'mean': Array([-0.22788152, -0.28456947], dtype=float32)}\n",
"\t Slowdist: (Array([ 0.16691391, -0.73883367], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 4300\n",
"\t Loss: 1.6039652824401855\n",
"\t Params: {'log_alpha': Array([-4.2492313, -1.5023878], dtype=float32), 'log_penalty_temperature': Array([1.6711892], dtype=float32), 'log_temperature': Array([0.33737296], dtype=float32), 'mean': Array([-0.02048574, -0.72920686], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(1.0335717, dtype=float32), 'loss_penalty_temperature': Array(-0.02639404, dtype=float32), 'loss_temperature': Array(0.5946347, dtype=float32), 'non_parametric_kl': Array(0.09536084, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00127711, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([ 0.00034723, -0.00077427], dtype=float32), 'log_penalty_temperature': Array([0.00069473], dtype=float32), 'log_temperature': Array([0.00202298], dtype=float32), 'mean': Array([-0.00384609, 0.00053002], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00014069, 0.00182185], dtype=float32), 'log_penalty_temperature': Array([-0.00023326], dtype=float32), 'log_temperature': Array([0.00270486], dtype=float32), 'mean': Array([0.42479032, 0.11346817], dtype=float32)}\n",
"\t Slowdist: (Array([-0.01663966, -0.72973686], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 4400\n",
"\t Loss: 1.514448642730713\n",
"\t Params: {'log_alpha': Array([-3.5384283, -1.5791984], dtype=float32), 'log_penalty_temperature': Array([2.0800655], dtype=float32), 'log_temperature': Array([0.8126696], dtype=float32), 'mean': Array([-0.744469 , -0.7174425], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.50046575, dtype=float32), 'loss_penalty_temperature': Array(-0.04429352, dtype=float32), 'loss_temperature': Array(1.0561153, dtype=float32), 'non_parametric_kl': Array(0.11013103, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00080171, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([ 0.00274924, -0.000682 ], dtype=float32), 'log_penalty_temperature': Array([0.00855682], dtype=float32), 'log_temperature': Array([0.00104446], dtype=float32), 'mean': Array([0.00032547, 0.00299666], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00028163, 0.00171006], dtype=float32), 'log_penalty_temperature': Array([0.00017609], dtype=float32), 'log_temperature': Array([-0.00701553], dtype=float32), 'mean': Array([-0.19172704, -0.80204546], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7447944, -0.7204392], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 4500\n",
"\t Loss: 1.663645625114441\n",
"\t Params: {'log_alpha': Array([-3.6635332, -1.6502742], dtype=float32), 'log_penalty_temperature': Array([2.4568334], dtype=float32), 'log_temperature': Array([0.8212998], dtype=float32), 'mean': Array([-0.75720745, -0.71485555], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.68455225, dtype=float32), 'loss_penalty_temperature': Array(-0.06745453, dtype=float32), 'loss_temperature': Array(1.0445368, dtype=float32), 'non_parametric_kl': Array(0.09013426, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00132168, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00140537, -0.00076152], dtype=float32), 'log_penalty_temperature': Array([0.00207205], dtype=float32), 'log_temperature': Array([0.00068202], dtype=float32), 'mean': Array([ 0.00146965, -0.00067934], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00025035, 0.00161175], dtype=float32), 'log_penalty_temperature': Array([-0.00029617], dtype=float32), 'log_temperature': Array([0.00685032], dtype=float32), 'mean': Array([-0.10970869, 0.36649236], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7586771 , -0.71417624], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 4600\n",
"\t Loss: 1.4971905946731567\n",
"\t Params: {'log_alpha': Array([-3.8013675, -1.7223854], dtype=float32), 'log_penalty_temperature': Array([2.5233548], dtype=float32), 'log_temperature': Array([0.8108527], dtype=float32), 'mean': Array([-0.71630305, -0.7579638 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.5321878, dtype=float32), 'loss_penalty_temperature': Array(-0.0625668, dtype=float32), 'loss_temperature': Array(1.0257037, dtype=float32), 'non_parametric_kl': Array(0.10401039, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00092551, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00133337, -0.00073467], dtype=float32), 'log_penalty_temperature': Array([-0.00116487], dtype=float32), 'log_temperature': Array([-0.00098486], dtype=float32), 'mean': Array([-0.00130582, 0.0013094 ], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00021881, 0.00151659], dtype=float32), 'log_penalty_temperature': Array([6.891848e-05], dtype=float32), 'log_temperature': Array([-0.00277725], dtype=float32), 'mean': Array([ 0.70091206, -0.88403124], dtype=float32)}\n",
"\t Slowdist: (Array([-0.71499723, -0.7592732 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 4700\n",
"\t Loss: 1.3900542259216309\n",
"\t Params: {'log_alpha': Array([-3.9257853, -1.7923344], dtype=float32), 'log_penalty_temperature': Array([2.3465612], dtype=float32), 'log_temperature': Array([0.81377214], dtype=float32), 'mean': Array([-0.6836508 , -0.72696185], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.32857886, dtype=float32), 'loss_penalty_temperature': Array(-0.03074341, dtype=float32), 'loss_temperature': Array(1.0904816, dtype=float32), 'non_parametric_kl': Array(0.10139387, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00064895, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00121205, -0.00057959], dtype=float32), 'log_penalty_temperature': Array([-0.00482991], dtype=float32), 'log_temperature': Array([0.00014371], dtype=float32), 'mean': Array([0.00015078, 0.00077108], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00019368, 0.00142858], dtype=float32), 'log_penalty_temperature': Array([0.00032046], dtype=float32), 'log_temperature': Array([-0.00096583], dtype=float32), 'mean': Array([0.3600305 , 0.02879279], dtype=float32)}\n",
"\t Slowdist: (Array([-0.6838016 , -0.72773296], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 4800\n",
"\t Loss: 1.6036906242370605\n",
"\t Params: {'log_alpha': Array([-4.0530195, -1.8539323], dtype=float32), 'log_penalty_temperature': Array([2.226617], dtype=float32), 'log_temperature': Array([0.79812455], dtype=float32), 'mean': Array([-0.66787577, -0.7554579 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.7102695, dtype=float32), 'loss_penalty_temperature': Array(-0.04689189, dtype=float32), 'loss_temperature': Array(0.9386847, dtype=float32), 'non_parametric_kl': Array(0.09693679, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00102954, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00130544, -0.00058217], dtype=float32), 'log_penalty_temperature': Array([0.00228509], dtype=float32), 'log_temperature': Array([-0.00092893], dtype=float32), 'mean': Array([-0.00117062, 0.00081603], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00017095, 0.0013548 ], dtype=float32), 'log_penalty_temperature': Array([-2.6652073e-05], dtype=float32), 'log_temperature': Array([0.00211281], dtype=float32), 'mean': Array([ 0.0345919, -1.0083373], dtype=float32)}\n",
"\t Slowdist: (Array([-0.66670513, -0.7562739 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 4900\n",
"\t Loss: 1.4418822526931763\n",
"\t Params: {'log_alpha': Array([-4.1483197, -1.9198583], dtype=float32), 'log_penalty_temperature': Array([2.4031074], dtype=float32), 'log_temperature': Array([0.83371115], dtype=float32), 'mean': Array([-0.7210007, -0.7459882], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.48020983, dtype=float32), 'loss_penalty_temperature': Array(-0.04381658, dtype=float32), 'loss_temperature': Array(1.0039631, dtype=float32), 'non_parametric_kl': Array(0.13607621, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00060268, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00087584, -0.00061502], dtype=float32), 'log_penalty_temperature': Array([0.0010249], dtype=float32), 'log_temperature': Array([0.00166394], dtype=float32), 'mean': Array([-0.00055742, -0.00239924], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00015559, 0.00127946], dtype=float32), 'log_penalty_temperature': Array([0.00036428], dtype=float32), 'log_temperature': Array([-0.02513746], dtype=float32), 'mean': Array([-1.0582699, -0.5433248], dtype=float32)}\n",
"\t Slowdist: (Array([-0.72044325, -0.7435889 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 5000\n",
"\t Loss: 1.992458701133728\n",
"\t Params: {'log_alpha': Array([-4.2593904, -1.9901031], dtype=float32), 'log_penalty_temperature': Array([2.3384607], dtype=float32), 'log_temperature': Array([0.7984999], dtype=float32), 'mean': Array([-0.6703219, -0.7522469], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(1.0373766, dtype=float32), 'loss_penalty_temperature': Array(-0.058481, dtype=float32), 'loss_temperature': Array(1.0121406, dtype=float32), 'non_parametric_kl': Array(0.09948082, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00133776, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00111993, -0.00062169], dtype=float32), 'log_penalty_temperature': Array([-0.00160528], dtype=float32), 'log_temperature': Array([-0.00055077], dtype=float32), 'mean': Array([0.00032455, 0.00179996], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00013949, 0.00120312], dtype=float32), 'log_penalty_temperature': Array([-0.00030819], dtype=float32), 'log_temperature': Array([0.00035829], dtype=float32), 'mean': Array([-0.22516556, -0.7188325 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.6706464 , -0.75404686], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 5100\n",
"\t Loss: 1.5075807571411133\n",
"\t Params: {'log_alpha': Array([-4.3565655, -2.051661 ], dtype=float32), 'log_penalty_temperature': Array([2.3581972], dtype=float32), 'log_temperature': Array([0.8143744], dtype=float32), 'mean': Array([-0.7160674, -0.7574166], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.5423977, dtype=float32), 'loss_penalty_temperature': Array(-0.04958641, dtype=float32), 'loss_temperature': Array(1.0134321, dtype=float32), 'non_parametric_kl': Array(0.11238277, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00099535, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.0010489 , -0.00056574], dtype=float32), 'log_penalty_temperature': Array([0.00451593], dtype=float32), 'log_temperature': Array([0.00028257], dtype=float32), 'mean': Array([0.00164637, 0.00042726], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00012673, 0.00113942], dtype=float32), 'log_penalty_temperature': Array([4.2677448e-06], dtype=float32), 'log_temperature': Array([-0.00858098], dtype=float32), 'mean': Array([ 0.4793455, -0.9757123], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7177137 , -0.75784385], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 5200\n",
"\t Loss: 1.4910800457000732\n",
"\t Params: {'log_alpha': Array([-4.4529047, -2.1006014], dtype=float32), 'log_penalty_temperature': Array([2.5718637], dtype=float32), 'log_temperature': Array([0.77714586], dtype=float32), 'mean': Array([-0.77477914, -0.72820705], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.5282289, dtype=float32), 'loss_penalty_temperature': Array(-0.06596357, dtype=float32), 'loss_temperature': Array(1.0275443, dtype=float32), 'non_parametric_kl': Array(0.11360885, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00126331, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-1.0189972e-03, -6.5966335e-05], dtype=float32), 'log_penalty_temperature': Array([0.00589109], dtype=float32), 'log_temperature': Array([0.00041908], dtype=float32), 'mean': Array([0.00042323, 0.00072901], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00011522, 0.00109045], dtype=float32), 'log_penalty_temperature': Array([-0.00024456], dtype=float32), 'log_temperature': Array([-0.00932162], dtype=float32), 'mean': Array([-0.5734161 , 0.34144628], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7752024, -0.7289361], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 5300\n",
"\t Loss: 1.397669792175293\n",
"\t Params: {'log_alpha': Array([-4.52204 , -2.1656787], dtype=float32), 'log_penalty_temperature': Array([2.5261047], dtype=float32), 'log_temperature': Array([0.80765057], dtype=float32), 'mean': Array([-0.7709242, -0.7483599], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.37832153, dtype=float32), 'loss_penalty_temperature': Array(-0.05737764, dtype=float32), 'loss_temperature': Array(1.0755315, dtype=float32), 'non_parametric_kl': Array(0.11095591, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00090638, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.0009581 , -0.00067912], dtype=float32), 'log_penalty_temperature': Array([-0.00161199], dtype=float32), 'log_temperature': Array([0.00060382], dtype=float32), 'mean': Array([-0.00117635, 0.00168947], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.0001076 , 0.00102938], dtype=float32), 'log_penalty_temperature': Array([8.662219e-05], dtype=float32), 'log_temperature': Array([-0.00757578], dtype=float32), 'mean': Array([-0.39065963, 0.43338344], dtype=float32)}\n",
"\t Slowdist: (Array([-0.76974785, -0.7500494 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 5400\n",
"\t Loss: 1.4973490238189697\n",
"\t Params: {'log_alpha': Array([-4.561755, -2.230867], dtype=float32), 'log_penalty_temperature': Array([2.5668833], dtype=float32), 'log_temperature': Array([0.83121216], dtype=float32), 'mean': Array([-0.7163919, -0.791718 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.5576599, dtype=float32), 'loss_penalty_temperature': Array(-0.05525138, dtype=float32), 'loss_temperature': Array(0.99381566, dtype=float32), 'non_parametric_kl': Array(0.07585713, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00093813, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00055809, -0.00048323], dtype=float32), 'log_penalty_temperature': Array([0.00245225], dtype=float32), 'log_temperature': Array([-0.00095753], dtype=float32), 'mean': Array([-0.00405033, 0.00187152], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([0.00010341, 0.00097055], dtype=float32), 'log_penalty_temperature': Array([5.748617e-05], dtype=float32), 'log_temperature': Array([0.0168231], dtype=float32), 'mean': Array([ 0.03209221, -0.4703173 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7123416 , -0.79358953], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 5500\n",
"\t Loss: 1.6289585828781128\n",
"\t Params: {'log_alpha': Array([-4.646823 , -2.2834606], dtype=float32), 'log_penalty_temperature': Array([2.2707133], dtype=float32), 'log_temperature': Array([0.80119723], dtype=float32), 'mean': Array([-0.6585817 , -0.71197295], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.8141037, dtype=float32), 'loss_penalty_temperature': Array(-0.04315935, dtype=float32), 'loss_temperature': Array(0.85694736, dtype=float32), 'non_parametric_kl': Array(0.10766677, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00050616, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00099945, -0.00064157], dtype=float32), 'log_penalty_temperature': Array([-0.00820792], dtype=float32), 'log_temperature': Array([-4.615907e-05], dtype=float32), 'mean': Array([ 0.00433721, -0.00091894], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([9.510305e-05, 9.255590e-04], dtype=float32), 'log_penalty_temperature': Array([0.00044795], dtype=float32), 'log_temperature': Array([-0.00529201], dtype=float32), 'mean': Array([-0.11834199, -0.06881957], dtype=float32)}\n",
"\t Slowdist: (Array([-0.66291887, -0.711054 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 5600\n",
"\t Loss: 1.5513050556182861\n",
"\t Params: {'log_alpha': Array([-4.725406, -2.350718], dtype=float32), 'log_penalty_temperature': Array([2.3009677], dtype=float32), 'log_temperature': Array([0.8199756], dtype=float32), 'mean': Array([-0.71398175, -0.7267592 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.6017201, dtype=float32), 'loss_penalty_temperature': Array(-0.04678169, dtype=float32), 'loss_temperature': Array(0.9953675, dtype=float32), 'non_parametric_kl': Array(0.08396924, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00093122, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00097224, -0.00067798], dtype=float32), 'log_penalty_temperature': Array([-0.00128396], dtype=float32), 'log_temperature': Array([0.00043655], dtype=float32), 'mean': Array([ 0.00112236, -0.00028689], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([8.797650e-05, 8.706258e-04], dtype=float32), 'log_penalty_temperature': Array([6.256764e-05], dtype=float32), 'log_temperature': Array([0.01112761], dtype=float32), 'mean': Array([0.12944141, 0.39194772], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7151041, -0.7264723], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 5700\n",
"\t Loss: 1.6376146078109741\n",
"\t Params: {'log_alpha': Array([-4.8011346, -2.4208188], dtype=float32), 'log_penalty_temperature': Array([2.315311], dtype=float32), 'log_temperature': Array([0.82437986], dtype=float32), 'mean': Array([-0.73218954, -0.7532774 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.6772865, dtype=float32), 'loss_penalty_temperature': Array(-0.04239136, dtype=float32), 'loss_temperature': Array(1.0017858, dtype=float32), 'non_parametric_kl': Array(0.10089241, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00108327, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00091519, -0.0006933 ], dtype=float32), 'log_penalty_temperature': Array([0.00335649], dtype=float32), 'log_temperature': Array([0.00030294], dtype=float32), 'mean': Array([-0.0016165 , -0.00150966], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([8.1607934e-05, 8.1650843e-04], dtype=float32), 'log_penalty_temperature': Array([-7.5738506e-05], dtype=float32), 'log_temperature': Array([-0.0006202], dtype=float32), 'mean': Array([-0.6882034, -1.2417636], dtype=float32)}\n",
"\t Slowdist: (Array([-0.73057306, -0.75176775], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 5800\n",
"\t Loss: 1.4127711057662964\n",
"\t Params: {'log_alpha': Array([-4.8933125, -2.4785693], dtype=float32), 'log_penalty_temperature': Array([2.164899], dtype=float32), 'log_temperature': Array([0.8422208], dtype=float32), 'mean': Array([-0.7134674, -0.7060472], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.38953352, dtype=float32), 'loss_penalty_temperature': Array(-0.03604721, dtype=float32), 'loss_temperature': Array(1.0584044, dtype=float32), 'non_parametric_kl': Array(0.08749542, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00094301, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00085003, -0.00036482], dtype=float32), 'log_penalty_temperature': Array([-0.00282112], dtype=float32), 'log_temperature': Array([-0.00114578], dtype=float32), 'mean': Array([-0.00070368, 0.00081293], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([7.447051e-05, 7.740031e-04], dtype=float32), 'log_penalty_temperature': Array([5.106544e-05], dtype=float32), 'log_temperature': Array([0.00874277], dtype=float32), 'mean': Array([-0.77185684, -0.13196912], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7127637, -0.7068601], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 5900\n",
"\t Loss: 1.7246118783950806\n",
"\t Params: {'log_alpha': Array([-4.9671917, -2.531075 ], dtype=float32), 'log_penalty_temperature': Array([2.3730574], dtype=float32), 'log_temperature': Array([0.8114506], dtype=float32), 'mean': Array([-0.68693644, -0.76676583], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.8280809, dtype=float32), 'loss_penalty_temperature': Array(-0.05687259, dtype=float32), 'loss_temperature': Array(0.95256853, dtype=float32), 'non_parametric_kl': Array(0.07427502, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00107894, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-5.6609587e-04, 3.2759810e-05], dtype=float32), 'log_penalty_temperature': Array([0.0019888], dtype=float32), 'log_temperature': Array([-3.1137653e-05], dtype=float32), 'mean': Array([-0.00120044, -0.00079479], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([6.9184192e-05, 7.3705986e-04], dtype=float32), 'log_penalty_temperature': Array([-7.2231254e-05], dtype=float32), 'log_temperature': Array([0.01781254], dtype=float32), 'mean': Array([ 0.47884762, -0.70597494], dtype=float32)}\n",
"\t Slowdist: (Array([-0.685736 , -0.76597106], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 6000\n",
"\t Loss: 1.4522042274475098\n",
"\t Params: {'log_alpha': Array([-5.030831 , -2.5920703], dtype=float32), 'log_penalty_temperature': Array([2.59474], dtype=float32), 'log_temperature': Array([0.82794267], dtype=float32), 'mean': Array([-0.72349656, -0.7665398 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.473536, dtype=float32), 'loss_penalty_temperature': Array(-0.06532797, dtype=float32), 'loss_temperature': Array(1.0432086, dtype=float32), 'non_parametric_kl': Array(0.11018831, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00099226, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00079899, -0.00069149], dtype=float32), 'log_penalty_temperature': Array([-0.00152503], dtype=float32), 'log_temperature': Array([0.00015817], dtype=float32), 'mean': Array([0.00100298, 0.00153505], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([6.4961270e-05, 6.9695315e-04], dtype=float32), 'log_penalty_temperature': Array([7.1814343e-06], dtype=float32), 'log_temperature': Array([-0.00708994], dtype=float32), 'mean': Array([-0.1582019 , 0.36184973], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7244995 , -0.76807487], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 6100\n",
"\t Loss: 1.5414754152297974\n",
"\t Params: {'log_alpha': Array([-5.115249 , -2.6540563], dtype=float32), 'log_penalty_temperature': Array([2.399818], dtype=float32), 'log_temperature': Array([0.82553184], dtype=float32), 'mean': Array([-0.717288 , -0.70560765], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.5505403, dtype=float32), 'loss_penalty_temperature': Array(-0.05953708, dtype=float32), 'loss_temperature': Array(1.0497319, dtype=float32), 'non_parametric_kl': Array(0.10117817, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00102248, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00071428, -0.00062646], dtype=float32), 'log_penalty_temperature': Array([-0.00776665], dtype=float32), 'log_temperature': Array([-0.00152547], dtype=float32), 'mean': Array([-0.00120432, 0.00052491], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([5.9728813e-05, 6.5777934e-04], dtype=float32), 'log_penalty_temperature': Array([-2.0638969e-05], dtype=float32), 'log_temperature': Array([-0.00081977], dtype=float32), 'mean': Array([0.91253436, 0.5149258 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7160837, -0.7061326], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 6200\n",
"\t Loss: 1.4947232007980347\n",
"\t Params: {'log_alpha': Array([-5.1797123, -2.7059019], dtype=float32), 'log_penalty_temperature': Array([2.4292734], dtype=float32), 'log_temperature': Array([0.8072793], dtype=float32), 'mean': Array([-0.6718716 , -0.75143576], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.59869784, dtype=float32), 'loss_penalty_temperature': Array(-0.05138786, dtype=float32), 'loss_temperature': Array(0.9467101, dtype=float32), 'non_parametric_kl': Array(0.10077566, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00082121, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([ 0.00090725, -0.0005337 ], dtype=float32), 'log_penalty_temperature': Array([-0.00340258], dtype=float32), 'log_temperature': Array([-0.00182506], dtype=float32), 'mean': Array([9.4521209e-05, 3.2980572e-03], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([5.5930621e-05, 6.2657334e-04], dtype=float32), 'log_penalty_temperature': Array([0.00016435], dtype=float32), 'log_temperature': Array([-0.00053657], dtype=float32), 'mean': Array([-0.12532917, -0.41670966], dtype=float32)}\n",
"\t Slowdist: (Array([-0.67196614, -0.7547338 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 6300\n",
"\t Loss: 1.7887980937957764\n",
"\t Params: {'log_alpha': Array([-5.2318544, -2.754501 ], dtype=float32), 'log_penalty_temperature': Array([2.418831], dtype=float32), 'log_temperature': Array([0.82517815], dtype=float32), 'mean': Array([-0.7146755 , -0.74177164], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.8406153, dtype=float32), 'loss_penalty_temperature': Array(-0.07553318, dtype=float32), 'loss_temperature': Array(1.0230454, dtype=float32), 'non_parametric_kl': Array(0.10837192, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00162973, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00038077, -0.00040543], dtype=float32), 'log_penalty_temperature': Array([0.00409673], dtype=float32), 'log_temperature': Array([-5.9474285e-05], dtype=float32), 'mean': Array([ 0.00034368, -0.00094506], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([5.3172185e-05, 5.9855747e-04], dtype=float32), 'log_penalty_temperature': Array([-0.00057807], dtype=float32), 'log_temperature': Array([-0.00582134], dtype=float32), 'mean': Array([ 0.82165694, -0.30313882], dtype=float32)}\n",
"\t Slowdist: (Array([-0.71501917, -0.7408266 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 6400\n",
"\t Loss: 1.5726470947265625\n",
"\t Params: {'log_alpha': Array([-5.293041 , -2.8149495], dtype=float32), 'log_penalty_temperature': Array([2.1844049], dtype=float32), 'log_temperature': Array([0.8231935], dtype=float32), 'mean': Array([-0.74631095, -0.72318393], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.5626811, dtype=float32), 'loss_penalty_temperature': Array(-0.06305401, dtype=float32), 'loss_temperature': Array(1.0723878, dtype=float32), 'non_parametric_kl': Array(0.10864969, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00144622, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00069913, -0.00037087], dtype=float32), 'log_penalty_temperature': Array([0.00574723], dtype=float32), 'log_temperature': Array([0.00051897], dtype=float32), 'mean': Array([ 0.00128449, -0.00431655], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([5.004792e-05, 5.654145e-04], dtype=float32), 'log_penalty_temperature': Array([-0.00040085], dtype=float32), 'log_temperature': Array([-0.00600978], dtype=float32), 'mean': Array([-0.23612581, 0.21673661], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7475954 , -0.71886736], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 6500\n",
"\t Loss: 1.9008150100708008\n",
"\t Params: {'log_alpha': Array([-5.373104 , -2.8409452], dtype=float32), 'log_penalty_temperature': Array([2.4360569], dtype=float32), 'log_temperature': Array([0.7942898], dtype=float32), 'mean': Array([-0.73636377, -0.75631654], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.93575263, dtype=float32), 'loss_penalty_temperature': Array(-0.06932205, dtype=float32), 'loss_temperature': Array(1.0337708, dtype=float32), 'non_parametric_kl': Array(0.10582413, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.0014547, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-8.1922248e-04, -4.4011304e-05], dtype=float32), 'log_penalty_temperature': Array([0.00447164], dtype=float32), 'log_temperature': Array([-0.0005567], dtype=float32), 'mean': Array([4.9123345e-03, 8.7719927e-05], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([4.622046e-05, 5.515356e-04], dtype=float32), 'log_penalty_temperature': Array([-0.000418], dtype=float32), 'log_temperature': Array([-0.00401206], dtype=float32), 'mean': Array([ 0.30898118, -0.79522663], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7412761, -0.7564043], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 6600\n",
"\t Loss: 1.2803170680999756\n",
"\t Params: {'log_alpha': Array([-5.4262877, -2.8977537], dtype=float32), 'log_penalty_temperature': Array([2.1486056], dtype=float32), 'log_temperature': Array([0.79630524], dtype=float32), 'mean': Array([-0.7131171 , -0.70363945], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.2419528, dtype=float32), 'loss_penalty_temperature': Array(-0.04085723, dtype=float32), 'loss_temperature': Array(1.0786403, dtype=float32), 'non_parametric_kl': Array(0.10055591, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00085673, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00042227, -0.00070987], dtype=float32), 'log_penalty_temperature': Array([-0.00171244], dtype=float32), 'log_temperature': Array([0.00195025], dtype=float32), 'mean': Array([ 0.0047164 , -0.00214306], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([4.3819698e-05, 5.2299886e-04], dtype=float32), 'log_penalty_temperature': Array([0.00012836], dtype=float32), 'log_temperature': Array([-0.000383], dtype=float32), 'mean': Array([-0.39836013, 0.24646443], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7178335 , -0.70149636], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 6700\n",
"\t Loss: 1.885292887687683\n",
"\t Params: {'log_alpha': Array([-5.412683 , -2.9501443], dtype=float32), 'log_penalty_temperature': Array([2.4990628], dtype=float32), 'log_temperature': Array([0.8147697], dtype=float32), 'mean': Array([-0.77054447, -0.7011013 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(1.0145726, dtype=float32), 'loss_penalty_temperature': Array(-0.05257157, dtype=float32), 'loss_temperature': Array(0.922737, dtype=float32), 'non_parametric_kl': Array(0.1110039, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00086312, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00011892, -0.00066721], dtype=float32), 'log_penalty_temperature': Array([0.00081695], dtype=float32), 'log_temperature': Array([0.00279186], dtype=float32), 'mean': Array([ 0.0024785 , -0.00246928], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([4.4403831e-05, 4.9761217e-04], dtype=float32), 'log_penalty_temperature': Array([0.00012647], dtype=float32), 'log_temperature': Array([-0.0076206], dtype=float32), 'mean': Array([-0.3887855 , 0.05873239], dtype=float32)}\n",
"\t Slowdist: (Array([-0.77302295, -0.698632 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 6800\n",
"\t Loss: 1.774002194404602\n",
"\t Params: {'log_alpha': Array([-5.472293 , -3.0107193], dtype=float32), 'log_penalty_temperature': Array([2.4360096], dtype=float32), 'log_temperature': Array([0.8167053], dtype=float32), 'mean': Array([-0.7025588, -0.7106018], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.8440896, dtype=float32), 'loss_penalty_temperature': Array(-0.04578009, dtype=float32), 'loss_temperature': Array(0.97516954, dtype=float32), 'non_parametric_kl': Array(0.08995473, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00145567, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00081878, -0.00054951], dtype=float32), 'log_penalty_temperature': Array([-0.00454108], dtype=float32), 'log_temperature': Array([0.00081586], dtype=float32), 'mean': Array([0.0054805 , 0.00058803], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([4.1874209e-05, 4.6968547e-04], dtype=float32), 'log_penalty_temperature': Array([-0.00041913], dtype=float32), 'log_temperature': Array([0.00696489], dtype=float32), 'mean': Array([-1.0759147 , -0.04977847], dtype=float32)}\n",
"\t Slowdist: (Array([-0.70803934, -0.71118987], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 6900\n",
"\t Loss: 1.4898117780685425\n",
"\t Params: {'log_alpha': Array([-5.527492 , -3.0549557], dtype=float32), 'log_penalty_temperature': Array([2.3860369], dtype=float32), 'log_temperature': Array([0.80851775], dtype=float32), 'mean': Array([-0.71726584, -0.6667529 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.39969558, dtype=float32), 'loss_penalty_temperature': Array(-0.03352553, dtype=float32), 'loss_temperature': Array(1.1231413, dtype=float32), 'non_parametric_kl': Array(0.11114654, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00067346, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00084291, -0.00015198], dtype=float32), 'log_penalty_temperature': Array([-0.00464068], dtype=float32), 'log_temperature': Array([0.00063373], dtype=float32), 'mean': Array([ 0.00038174, -0.00111737], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([3.9635284e-05, 4.5010526e-04], dtype=float32), 'log_penalty_temperature': Array([0.00029911], dtype=float32), 'log_temperature': Array([-0.00770966], dtype=float32), 'mean': Array([ 0.03825758, -0.06415118], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7176476, -0.6656355], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 7000\n",
"\t Loss: 1.5742888450622559\n",
"\t Params: {'log_alpha': Array([-5.608014 , -3.1067548], dtype=float32), 'log_penalty_temperature': Array([2.4144409], dtype=float32), 'log_temperature': Array([0.80083054], dtype=float32), 'mean': Array([-0.7206037 , -0.71691245], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.5358858, dtype=float32), 'loss_penalty_temperature': Array(-0.04637678, dtype=float32), 'loss_temperature': Array(1.0843052, dtype=float32), 'non_parametric_kl': Array(0.11614688, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00083358, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00082606, -0.00065523], dtype=float32), 'log_penalty_temperature': Array([-0.000774], dtype=float32), 'log_temperature': Array([-0.00017836], dtype=float32), 'mean': Array([-0.00070183, -0.0015928 ], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([3.6579488e-05, 4.2856351e-04], dtype=float32), 'log_penalty_temperature': Array([0.00015272], dtype=float32), 'log_temperature': Array([-0.01114445], dtype=float32), 'mean': Array([ 0.44936985, -0.05540894], dtype=float32)}\n",
"\t Slowdist: (Array([-0.71990186, -0.71531963], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 7100\n",
"\t Loss: 1.2916717529296875\n",
"\t Params: {'log_alpha': Array([-5.6660357, -3.1634235], dtype=float32), 'log_penalty_temperature': Array([2.2946563], dtype=float32), 'log_temperature': Array([0.8255672], dtype=float32), 'mean': Array([-0.7197243 , -0.75529563], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.22477275, dtype=float32), 'loss_penalty_temperature': Array(-0.04810765, dtype=float32), 'loss_temperature': Array(1.1145579, dtype=float32), 'non_parametric_kl': Array(0.09680487, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00099747, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-7.1631995e-04, 7.2794501e-05], dtype=float32), 'log_penalty_temperature': Array([0.00103345], dtype=float32), 'log_temperature': Array([0.0007332], dtype=float32), 'mean': Array([0.00087061, 0.0021115 ], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([3.4520828e-05, 4.0562774e-04], dtype=float32), 'log_penalty_temperature': Array([2.275539e-06], dtype=float32), 'log_temperature': Array([0.00222142], dtype=float32), 'mean': Array([ 0.7630573 , -0.10629168], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7205949, -0.7574071], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 7200\n",
"\t Loss: 1.4580386877059937\n",
"\t Params: {'log_alpha': Array([-5.7441382, -3.2218566], dtype=float32), 'log_penalty_temperature': Array([2.378932], dtype=float32), 'log_temperature': Array([0.8189237], dtype=float32), 'mean': Array([-0.74434465, -0.7059956 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.40997934, dtype=float32), 'loss_penalty_temperature': Array(-0.05854732, dtype=float32), 'loss_temperature': Array(1.1061833, dtype=float32), 'non_parametric_kl': Array(0.11112963, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00123919, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00065327, -0.00061406], dtype=float32), 'log_penalty_temperature': Array([0.00316541], dtype=float32), 'log_temperature': Array([0.00170328], dtype=float32), 'mean': Array([0.00033279, 0.0007527 ], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([3.1933534e-05, 3.8374111e-04], dtype=float32), 'log_penalty_temperature': Array([-0.00021884], dtype=float32), 'log_temperature': Array([-0.00772007], dtype=float32), 'mean': Array([ 0.55960935, -0.00594892], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7446774, -0.7067483], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 7300\n",
"\t Loss: 1.7015234231948853\n",
"\t Params: {'log_alpha': Array([-5.820042 , -3.2830253], dtype=float32), 'log_penalty_temperature': Array([2.315686], dtype=float32), 'log_temperature': Array([0.79167086], dtype=float32), 'mean': Array([-0.72285116, -0.7428225 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.87371314, dtype=float32), 'loss_penalty_temperature': Array(-0.0577099, dtype=float32), 'loss_temperature': Array(0.8851221, dtype=float32), 'non_parametric_kl': Array(0.10770912, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00092204, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00074769, -0.00033324], dtype=float32), 'log_penalty_temperature': Array([-0.00100749], dtype=float32), 'log_temperature': Array([0.00062979], dtype=float32), 'mean': Array([ 0.00030688, -0.00122391], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([2.960906e-05, 3.616975e-04], dtype=float32), 'log_penalty_temperature': Array([7.094697e-05], dtype=float32), 'log_temperature': Array([-0.00530425], dtype=float32), 'mean': Array([-0.5663503 , 0.27656466], dtype=float32)}\n",
"\t Slowdist: (Array([-0.72315806, -0.7415986 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 7400\n",
"\t Loss: 1.4999785423278809\n",
"\t Params: {'log_alpha': Array([-5.8894825, -3.3467524], dtype=float32), 'log_penalty_temperature': Array([2.36722], dtype=float32), 'log_temperature': Array([0.78515303], dtype=float32), 'mean': Array([-0.7589408, -0.7282177], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.5838524, dtype=float32), 'loss_penalty_temperature': Array(-0.06453513, dtype=float32), 'loss_temperature': Array(0.9802874, dtype=float32), 'non_parametric_kl': Array(0.08415233, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.0010719, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00057271, -0.000639 ], dtype=float32), 'log_penalty_temperature': Array([0.00255206], dtype=float32), 'log_temperature': Array([-0.00216116], dtype=float32), 'mean': Array([-0.00060648, -0.00396204], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([2.7623431e-05, 3.4022666e-04], dtype=float32), 'log_penalty_temperature': Array([-6.56774e-05], dtype=float32), 'log_temperature': Array([0.01089142], dtype=float32), 'mean': Array([0.38201004, 0.12144436], dtype=float32)}\n",
"\t Slowdist: (Array([-0.75833434, -0.7242557 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 7500\n",
"\t Loss: 1.6470222473144531\n",
"\t Params: {'log_alpha': Array([-5.95131 , -3.408117], dtype=float32), 'log_penalty_temperature': Array([2.4131064], dtype=float32), 'log_temperature': Array([0.79759973], dtype=float32), 'mean': Array([-0.73524773, -0.73994106], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.69751465, dtype=float32), 'loss_penalty_temperature': Array(-0.05956492, dtype=float32), 'loss_temperature': Array(1.0087206, dtype=float32), 'non_parametric_kl': Array(0.09130882, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00122495, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00020236, -0.00064364], dtype=float32), 'log_penalty_temperature': Array([0.00385431], dtype=float32), 'log_temperature': Array([0.00114514], dtype=float32), 'mean': Array([ 0.00292068, -0.00322262], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([2.5961965e-05, 3.2062715e-04], dtype=float32), 'log_penalty_temperature': Array([-0.0002064], dtype=float32), 'log_temperature': Array([0.00598997], dtype=float32), 'mean': Array([-0.83658046, 0.537142 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7381684, -0.7367184], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 7600\n",
"\t Loss: 1.5584962368011475\n",
"\t Params: {'log_alpha': Array([-6.002321 , -3.4677868], dtype=float32), 'log_penalty_temperature': Array([2.6072166], dtype=float32), 'log_temperature': Array([0.7951684], dtype=float32), 'mean': Array([-0.7426274, -0.7121231], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.6421064, dtype=float32), 'loss_penalty_temperature': Array(-0.0568133, dtype=float32), 'loss_temperature': Array(0.9728711, dtype=float32), 'non_parametric_kl': Array(0.09805866, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.0008807, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00020589, -0.00066594], dtype=float32), 'log_penalty_temperature': Array([-0.00529108], dtype=float32), 'log_temperature': Array([-0.00270376], dtype=float32), 'mean': Array([-0.00066319, -0.00048837], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([2.4674124e-05, 3.0262361e-04], dtype=float32), 'log_penalty_temperature': Array([0.00011117], dtype=float32), 'log_temperature': Array([0.00133848], dtype=float32), 'mean': Array([ 0.05725333, -0.09692129], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7419642 , -0.71163476], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 7700\n",
"\t Loss: 1.3200829029083252\n",
"\t Params: {'log_alpha': Array([-6.0511866, -3.5227795], dtype=float32), 'log_penalty_temperature': Array([2.4565923], dtype=float32), 'log_temperature': Array([0.8192495], dtype=float32), 'mean': Array([-0.7530316 , -0.74138355], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.3371985, dtype=float32), 'loss_penalty_temperature': Array(-0.04056432, dtype=float32), 'loss_temperature': Array(1.0231342, dtype=float32), 'non_parametric_kl': Array(0.1112653, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.0006512, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00068625, -0.00026595], dtype=float32), 'log_penalty_temperature': Array([0.001006], dtype=float32), 'log_temperature': Array([-5.9259506e-05], dtype=float32), 'mean': Array([-0.00053252, 0.00055882], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([2.3511408e-05, 2.8678396e-04], dtype=float32), 'log_penalty_temperature': Array([0.00032121], dtype=float32), 'log_temperature': Array([-0.00781917], dtype=float32), 'mean': Array([-0.33585253, 0.2465067 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7524991 , -0.74194235], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 7800\n",
"\t Loss: 1.6880261898040771\n",
"\t Params: {'log_alpha': Array([-6.1111546, -3.5863762], dtype=float32), 'log_penalty_temperature': Array([2.320816], dtype=float32), 'log_temperature': Array([0.76445913], dtype=float32), 'mean': Array([-0.69202816, -0.7455122 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.74765646, dtype=float32), 'loss_penalty_temperature': Array(-0.05571369, dtype=float32), 'loss_temperature': Array(0.9957879, dtype=float32), 'non_parametric_kl': Array(0.0934229, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00144247, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00073979, -0.0006664 ], dtype=float32), 'log_penalty_temperature': Array([-0.00052497], dtype=float32), 'log_temperature': Array([0.00075781], dtype=float32), 'mean': Array([ 0.00323151, -0.00155978], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([2.2147135e-05, 2.6969472e-04], dtype=float32), 'log_penalty_temperature': Array([-0.00040296], dtype=float32), 'log_temperature': Array([0.0044867], dtype=float32), 'mean': Array([ 0.12955846, -0.8622572 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.6952597, -0.7439524], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 7900\n",
"\t Loss: 1.4556248188018799\n",
"\t Params: {'log_alpha': Array([-6.159354 , -3.6268108], dtype=float32), 'log_penalty_temperature': Array([2.467892], dtype=float32), 'log_temperature': Array([0.79820114], dtype=float32), 'mean': Array([-0.7193236 , -0.71182907], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.48453546, dtype=float32), 'loss_penalty_temperature': Array(-0.02973436, dtype=float32), 'loss_temperature': Array(1.00054, dtype=float32), 'non_parametric_kl': Array(0.091179, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00043385, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00032238, -0.00016136], dtype=float32), 'log_penalty_temperature': Array([-0.00497566], dtype=float32), 'log_temperature': Array([-0.00061293], dtype=float32), 'mean': Array([-0.0038064 , 0.00056156], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([2.1098389e-05, 2.5915683e-04], dtype=float32), 'log_penalty_temperature': Array([0.00052208], dtype=float32), 'log_temperature': Array([0.00608398], dtype=float32), 'mean': Array([-0.31335032, -0.5311105 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.71551716, -0.7123906 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 8000\n",
"\t Loss: 1.5466573238372803\n",
"\t Params: {'log_alpha': Array([-6.1846824, -3.6833465], dtype=float32), 'log_penalty_temperature': Array([2.5652742], dtype=float32), 'log_temperature': Array([0.8040289], dtype=float32), 'mean': Array([-0.74597037, -0.731411 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.64265656, dtype=float32), 'loss_penalty_temperature': Array(-0.06187137, dtype=float32), 'loss_temperature': Array(0.96560305, dtype=float32), 'non_parametric_kl': Array(0.09311168, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00100413, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00074139, -0.00052663], dtype=float32), 'log_penalty_temperature': Array([-0.00073997], dtype=float32), 'log_temperature': Array([-0.00019906], dtype=float32), 'mean': Array([-0.00014919, 0.00407315], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([2.0580408e-05, 2.4534852e-04], dtype=float32), 'log_penalty_temperature': Array([-3.803786e-06], dtype=float32), 'log_temperature': Array([0.0047589], dtype=float32), 'mean': Array([ 0.52774394, -0.63905 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7458212, -0.7354841], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 8100\n",
"\t Loss: 1.644338607788086\n",
"\t Params: {'log_alpha': Array([-6.2502975, -3.7198524], dtype=float32), 'log_penalty_temperature': Array([2.2109387], dtype=float32), 'log_temperature': Array([0.8192656], dtype=float32), 'mean': Array([-0.68971604, -0.6859378 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.6432339, dtype=float32), 'loss_penalty_temperature': Array(-0.04584932, dtype=float32), 'loss_temperature': Array(1.0466952, dtype=float32), 'non_parametric_kl': Array(0.08531148, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00097228, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00045646, -0.00013141], dtype=float32), 'log_penalty_temperature': Array([-0.0033054], dtype=float32), 'log_temperature': Array([-0.00387483], dtype=float32), 'mean': Array([ 0.00022475, -0.00320961], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([1.9270403e-05, 2.3667022e-04], dtype=float32), 'log_penalty_temperature': Array([2.5036983e-05], dtype=float32), 'log_temperature': Array([0.01020714], dtype=float32), 'mean': Array([-0.36887613, 1.0903556 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.6899408 , -0.68272823], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 8200\n",
"\t Loss: 1.300633430480957\n",
"\t Params: {'log_alpha': Array([-6.3125577, -3.7704895], dtype=float32), 'log_penalty_temperature': Array([2.450081], dtype=float32), 'log_temperature': Array([0.798194], dtype=float32), 'mean': Array([-0.7118782, -0.7154638], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.21955222, dtype=float32), 'loss_penalty_temperature': Array(-0.04146672, dtype=float32), 'loss_temperature': Array(1.1223019, dtype=float32), 'non_parametric_kl': Array(0.08510669, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00086119, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00055155, -0.00062886], dtype=float32), 'log_penalty_temperature': Array([-0.00031714], dtype=float32), 'log_temperature': Array([-0.00106349], dtype=float32), 'mean': Array([0.00166994, 0.00176997], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([1.8111035e-05, 2.2535711e-04], dtype=float32), 'log_penalty_temperature': Array([0.00012774], dtype=float32), 'log_temperature': Array([0.01027361], dtype=float32), 'mean': Array([-0.09802828, 0.33917537], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7135481, -0.7172338], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 8300\n",
"\t Loss: 1.254336953163147\n",
"\t Params: {'log_alpha': Array([-6.379435 , -3.8354983], dtype=float32), 'log_penalty_temperature': Array([2.3861678], dtype=float32), 'log_temperature': Array([0.8108102], dtype=float32), 'mean': Array([-0.71847415, -0.74447507], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.19845597, dtype=float32), 'loss_penalty_temperature': Array(-0.04984281, dtype=float32), 'loss_temperature': Array(1.1054932, dtype=float32), 'non_parametric_kl': Array(0.10232684, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00064966, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00072232, -0.00055457], dtype=float32), 'log_penalty_temperature': Array([0.00263169], dtype=float32), 'log_temperature': Array([0.00010165], dtype=float32), 'mean': Array([-0.00150492, 0.00121062], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([1.6944308e-05, 2.1145749e-04], dtype=float32), 'log_penalty_temperature': Array([0.00032075], dtype=float32), 'log_temperature': Array([-0.00161076], dtype=float32), 'mean': Array([0.6706293 , 0.93226856], dtype=float32)}\n",
"\t Slowdist: (Array([-0.71696925, -0.7456857 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 8400\n",
"\t Loss: 1.2810614109039307\n",
"\t Params: {'log_alpha': Array([-6.4338875, -3.8944588], dtype=float32), 'log_penalty_temperature': Array([2.6201413], dtype=float32), 'log_temperature': Array([0.8239335], dtype=float32), 'mean': Array([-0.7290428, -0.7640125], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.29151118, dtype=float32), 'loss_penalty_temperature': Array(-0.05286044, dtype=float32), 'loss_temperature': Array(1.0421929, dtype=float32), 'non_parametric_kl': Array(0.09448501, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00068635, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00070484, -0.00041329], dtype=float32), 'log_penalty_temperature': Array([0.00101875], dtype=float32), 'log_temperature': Array([0.00160235], dtype=float32), 'mean': Array([-0.00269167, 0.00605734], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([1.6047477e-05, 1.9956430e-04], dtype=float32), 'log_penalty_temperature': Array([0.00029227], dtype=float32), 'log_temperature': Array([0.00383141], dtype=float32), 'mean': Array([ 0.7376607, -0.5320484], dtype=float32)}\n",
"\t Slowdist: (Array([-0.72635114, -0.77006984], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 8500\n",
"\t Loss: 1.6595110893249512\n",
"\t Params: {'log_alpha': Array([-6.4932213, -3.9302938], dtype=float32), 'log_penalty_temperature': Array([2.4312696], dtype=float32), 'log_temperature': Array([0.82035613], dtype=float32), 'mean': Array([-0.7693642, -0.7397693], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.7792685, dtype=float32), 'loss_penalty_temperature': Array(-0.09173838, dtype=float32), 'loss_temperature': Array(0.9717713, dtype=float32), 'non_parametric_kl': Array(0.11363997, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00168656, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00048964, -0.00061394], dtype=float32), 'log_penalty_temperature': Array([0.00164156], dtype=float32), 'log_temperature': Array([0.00329888], dtype=float32), 'mean': Array([-0.00210921, -0.00257132], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([1.5121163e-05, 1.9271280e-04], dtype=float32), 'log_penalty_temperature': Array([-0.00063107], dtype=float32), 'log_temperature': Array([-0.00946086], dtype=float32), 'mean': Array([0.98418653, 0.1411263 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.76725495, -0.73719794], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 8600\n",
"\t Loss: 1.5099983215332031\n",
"\t Params: {'log_alpha': Array([-6.554261 , -3.9837031], dtype=float32), 'log_penalty_temperature': Array([2.2152743], dtype=float32), 'log_temperature': Array([0.7840993], dtype=float32), 'mean': Array([-0.6889753, -0.7551133], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.6030709, dtype=float32), 'loss_penalty_temperature': Array(-0.05998194, dtype=float32), 'loss_temperature': Array(0.9667107, dtype=float32), 'non_parametric_kl': Array(0.10491432, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00129076, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00061927, 0.00028182], dtype=float32), 'log_penalty_temperature': Array([0.00506928], dtype=float32), 'log_temperature': Array([-0.00349134], dtype=float32), 'mean': Array([-3.7626694e-05, -7.2202884e-04], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([1.42288945e-05, 1.82712727e-04], dtype=float32), 'log_penalty_temperature': Array([-0.00026203], dtype=float32), 'log_temperature': Array([-0.00337763], dtype=float32), 'mean': Array([ 0.5339562 , -0.12747228], dtype=float32)}\n",
"\t Slowdist: (Array([-0.68893766, -0.75439125], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 8700\n",
"\t Loss: 1.4586509466171265\n",
"\t Params: {'log_alpha': Array([-6.617687 , -4.0380116], dtype=float32), 'log_penalty_temperature': Array([2.127689], dtype=float32), 'log_temperature': Array([0.8533487], dtype=float32), 'mean': Array([-0.71496135, -0.6908974 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.45254654, dtype=float32), 'loss_penalty_temperature': Array(-0.05147193, dtype=float32), 'loss_temperature': Array(1.0573881, dtype=float32), 'non_parametric_kl': Array(0.09699976, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00108881, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00064497, -0.00046514], dtype=float32), 'log_penalty_temperature': Array([-0.00190995], dtype=float32), 'log_temperature': Array([-3.0729962e-06], dtype=float32), 'mean': Array([-0.00364351, 0.00065743], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([1.3355947e-05, 1.7334895e-04], dtype=float32), 'log_penalty_temperature': Array([-7.932965e-05], dtype=float32), 'log_temperature': Array([0.00210389], dtype=float32), 'mean': Array([0.77573216, 0.07999189], dtype=float32)}\n",
"\t Slowdist: (Array([-0.71131784, -0.69155484], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 8800\n",
"\t Loss: 1.4709903001785278\n",
"\t Params: {'log_alpha': Array([-6.6710796, -4.0624623], dtype=float32), 'log_penalty_temperature': Array([2.3390172], dtype=float32), 'log_temperature': Array([0.8178568], dtype=float32), 'mean': Array([-0.764758 , -0.71136576], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.49314305, dtype=float32), 'loss_penalty_temperature': Array(-0.05631953, dtype=float32), 'loss_temperature': Array(1.0339835, dtype=float32), 'non_parametric_kl': Array(0.08741156, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00107289, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00048483, -0.00026265], dtype=float32), 'log_penalty_temperature': Array([0.00889829], dtype=float32), 'log_temperature': Array([-0.0020625], dtype=float32), 'mean': Array([0.00157168, 0.00189556], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([1.2660396e-05, 1.6919912e-04], dtype=float32), 'log_penalty_temperature': Array([-6.6409884e-05], dtype=float32), 'log_temperature': Array([0.00873916], dtype=float32), 'mean': Array([0.14637814, 0.28775752], dtype=float32)}\n",
"\t Slowdist: (Array([-0.76632965, -0.7132613 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 8900\n",
"\t Loss: 1.2123603820800781\n",
"\t Params: {'log_alpha': Array([-6.7245197, -4.1150274], dtype=float32), 'log_penalty_temperature': Array([2.7688088], dtype=float32), 'log_temperature': Array([0.86401665], dtype=float32), 'mean': Array([-0.75205326, -0.8043602 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.14530157, dtype=float32), 'loss_penalty_temperature': Array(-0.05932507, dtype=float32), 'loss_temperature': Array(1.1262099, dtype=float32), 'non_parametric_kl': Array(0.10742441, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00073932, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00053743, -0.00065636], dtype=float32), 'log_penalty_temperature': Array([0.00403645], dtype=float32), 'log_temperature': Array([0.00468077], dtype=float32), 'mean': Array([ 0.00177497, -0.00227435], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([1.2003004e-05, 1.6073624e-04], dtype=float32), 'log_penalty_temperature': Array([0.00024523], dtype=float32), 'log_temperature': Array([-0.00521575], dtype=float32), 'mean': Array([ 0.4103798 , -0.09376205], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7538282, -0.8020859], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 9000\n",
"\t Loss: 1.483298659324646\n",
"\t Params: {'log_alpha': Array([-6.7871213, -4.1761928], dtype=float32), 'log_penalty_temperature': Array([2.4074605], dtype=float32), 'log_temperature': Array([0.8349451], dtype=float32), 'mean': Array([-0.7180599 , -0.71657795], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.48389745, dtype=float32), 'loss_penalty_temperature': Array(-0.05548671, dtype=float32), 'loss_temperature': Array(1.0547241, dtype=float32), 'non_parametric_kl': Array(0.09472062, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00132106, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00066804, -0.00059751], dtype=float32), 'log_penalty_temperature': Array([-0.00075442], dtype=float32), 'log_temperature': Array([0.00024593], dtype=float32), 'mean': Array([ 0.00082538, -0.00127481], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([1.12769285e-05, 1.51334985e-04], dtype=float32), 'log_penalty_temperature': Array([-0.00029448], dtype=float32), 'log_temperature': Array([0.00368149], dtype=float32), 'mean': Array([0.30883592, 0.963312 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7188853, -0.7153031], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 9100\n",
"\t Loss: 1.64036226272583\n",
"\t Params: {'log_alpha': Array([-6.8548226, -4.2269287], dtype=float32), 'log_penalty_temperature': Array([2.4021707], dtype=float32), 'log_temperature': Array([0.82061005], dtype=float32), 'mean': Array([-0.7761994 , -0.73832256], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.6803206, dtype=float32), 'loss_penalty_temperature': Array(-0.07083843, dtype=float32), 'loss_temperature': Array(1.0307246, dtype=float32), 'non_parametric_kl': Array(0.09436078, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00114683, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00053639, -0.00029089], dtype=float32), 'log_penalty_temperature': Array([0.00376842], dtype=float32), 'log_temperature': Array([-0.00066724], dtype=float32), 'mean': Array([-0.00276359, 0.00121997], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([1.0538128e-05, 1.4391266e-04], dtype=float32), 'log_penalty_temperature': Array([-0.0001346], dtype=float32), 'log_temperature': Array([0.00391645], dtype=float32), 'mean': Array([-0.21734789, -0.14037757], dtype=float32)}\n",
"\t Slowdist: (Array([-0.77343583, -0.73954254], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 9200\n",
"\t Loss: 1.680344581604004\n",
"\t Params: {'log_alpha': Array([-6.9165134, -4.2848186], dtype=float32), 'log_penalty_temperature': Array([2.5083637], dtype=float32), 'log_temperature': Array([0.79925704], dtype=float32), 'mean': Array([-0.7059717, -0.7303236], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.72053707, dtype=float32), 'loss_penalty_temperature': Array(-0.06434027, dtype=float32), 'loss_temperature': Array(1.0240009, dtype=float32), 'non_parametric_kl': Array(0.08820509, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00110975, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00044127, -0.00067208], dtype=float32), 'log_penalty_temperature': Array([-0.00335366], dtype=float32), 'log_temperature': Array([6.118088e-05], dtype=float32), 'mean': Array([ 0.00081079, -0.00225375], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([9.9073468e-06, 1.3597924e-04], dtype=float32), 'log_penalty_temperature': Array([-0.00010161], dtype=float32), 'log_temperature': Array([0.00813611], dtype=float32), 'mean': Array([0.3065004, 0.5817537], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7067825 , -0.72806984], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 9300\n",
"\t Loss: 1.5738105773925781\n",
"\t Params: {'log_alpha': Array([-6.9432836, -4.3294597], dtype=float32), 'log_penalty_temperature': Array([2.685096], dtype=float32), 'log_temperature': Array([0.84594053], dtype=float32), 'mean': Array([-0.72987694, -0.7659529 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.61415696, dtype=float32), 'loss_penalty_temperature': Array(-0.05943195, dtype=float32), 'loss_temperature': Array(1.018945, dtype=float32), 'non_parametric_kl': Array(0.10570996, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00088945, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-2.8309138e-05, -6.3326216e-04], dtype=float32), 'log_penalty_temperature': Array([0.00182869], dtype=float32), 'log_temperature': Array([8.4106614e-05], dtype=float32), 'mean': Array([ 0.00406933, -0.00077955], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([9.641921e-06, 1.301148e-04], dtype=float32), 'log_penalty_temperature': Array([0.00010346], dtype=float32), 'log_temperature': Array([-0.00399533], dtype=float32), 'mean': Array([-0.63308007, -0.43187022], dtype=float32)}\n",
"\t Slowdist: (Array([-0.73394626, -0.7651733 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 9400\n",
"\t Loss: 1.6582413911819458\n",
"\t Params: {'log_alpha': Array([-6.990595 , -4.3923073], dtype=float32), 'log_penalty_temperature': Array([2.31352], dtype=float32), 'log_temperature': Array([0.8361588], dtype=float32), 'mean': Array([-0.7415044 , -0.71169466], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.72477657, dtype=float32), 'loss_penalty_temperature': Array(-0.06234156, dtype=float32), 'loss_temperature': Array(0.9956742, dtype=float32), 'non_parametric_kl': Array(0.10696869, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00124795, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([ 5.6374956e-05, -5.9616478e-04], dtype=float32), 'log_penalty_temperature': Array([0.00147561], dtype=float32), 'log_temperature': Array([0.00157764], dtype=float32), 'mean': Array([0.00012554, 0.0024677 ], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([9.196007e-06, 1.222815e-04], dtype=float32), 'log_penalty_temperature': Array([-0.00022555], dtype=float32), 'log_temperature': Array([-0.00485927], dtype=float32), 'mean': Array([-0.20784478, 0.15205158], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7416299 , -0.71416235], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 9500\n",
"\t Loss: 1.720261573791504\n",
"\t Params: {'log_alpha': Array([-7.0277452, -4.4534626], dtype=float32), 'log_penalty_temperature': Array([2.1224782], dtype=float32), 'log_temperature': Array([0.8692089], dtype=float32), 'mean': Array([-0.68217087, -0.70849115], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.80494493, dtype=float32), 'loss_penalty_temperature': Array(-0.04521666, dtype=float32), 'loss_temperature': Array(0.96040857, dtype=float32), 'non_parametric_kl': Array(0.08130201, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00141234, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00062856, -0.00059432], dtype=float32), 'log_penalty_temperature': Array([-0.0038217], dtype=float32), 'log_temperature': Array([0.00059405], dtype=float32), 'mean': Array([0.00095991, 0.00050033], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([8.8670013e-06, 1.1511071e-04], dtype=float32), 'log_penalty_temperature': Array([-0.00036836], dtype=float32), 'log_temperature': Array([0.01317204], dtype=float32), 'mean': Array([ 0.6974995, -1.1129929], dtype=float32)}\n",
"\t Slowdist: (Array([-0.6831308 , -0.70899147], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 9600\n",
"\t Loss: 1.580576777458191\n",
"\t Params: {'log_alpha': Array([-7.088562, -4.509091], dtype=float32), 'log_penalty_temperature': Array([2.1785932], dtype=float32), 'log_temperature': Array([0.81321967], dtype=float32), 'mean': Array([-0.7486918, -0.7011437], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.67268807, dtype=float32), 'loss_penalty_temperature': Array(-0.053737, dtype=float32), 'loss_temperature': Array(0.96150774, dtype=float32), 'non_parametric_kl': Array(0.10017306, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00128587, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00017818, -0.00062461], dtype=float32), 'log_penalty_temperature': Array([0.00729288], dtype=float32), 'log_temperature': Array([-0.00034926], dtype=float32), 'mean': Array([-0.00025096, 0.00026857], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([8.3404921e-06, 1.0895324e-04], dtype=float32), 'log_penalty_temperature': Array([-0.00025661], dtype=float32), 'log_temperature': Array([-0.00011995], dtype=float32), 'mean': Array([0.0833376 , 0.19689672], dtype=float32)}\n",
"\t Slowdist: (Array([-0.74844086, -0.70141226], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 9700\n",
"\t Loss: 1.4982753992080688\n",
"\t Params: {'log_alpha': Array([-7.139118 , -4.5500464], dtype=float32), 'log_penalty_temperature': Array([2.3934414], dtype=float32), 'log_temperature': Array([0.846815], dtype=float32), 'mean': Array([-0.7449248, -0.6933736], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.6094734, dtype=float32), 'loss_penalty_temperature': Array(-0.05294891, dtype=float32), 'loss_temperature': Array(0.9416378, dtype=float32), 'non_parametric_kl': Array(0.08784834, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00100351, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00058205, -0.00041577], dtype=float32), 'log_penalty_temperature': Array([-0.00045694], dtype=float32), 'log_temperature': Array([-0.00081286], dtype=float32), 'mean': Array([-0.00161129, -0.00199378], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([7.9328393e-06, 1.0460527e-04], dtype=float32), 'log_penalty_temperature': Array([-3.183391e-06], dtype=float32), 'log_temperature': Array([0.00850711], dtype=float32), 'mean': Array([0.800719 , 0.7054167], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7433135 , -0.69137985], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 9800\n",
"\t Loss: 1.5339332818984985\n",
"\t Params: {'log_alpha': Array([-7.2016044, -4.6013446], dtype=float32), 'log_penalty_temperature': Array([2.3364275], dtype=float32), 'log_temperature': Array([0.8454283], dtype=float32), 'mean': Array([-0.74793625, -0.71818143], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.6751684, dtype=float32), 'loss_penalty_temperature': Array(-0.05625369, dtype=float32), 'loss_temperature': Array(0.9149112, dtype=float32), 'non_parametric_kl': Array(0.09596737, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00088895, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00059528, -0.00055599], dtype=float32), 'log_penalty_temperature': Array([-0.00021562], dtype=float32), 'log_temperature': Array([-0.00226789], dtype=float32), 'mean': Array([-0.00236039, -0.00492804], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([7.4527666e-06, 9.9440324e-05], dtype=float32), 'log_penalty_temperature': Array([0.0001012], dtype=float32), 'log_temperature': Array([0.00282315], dtype=float32), 'mean': Array([-0.6523447 , 0.70174015], dtype=float32)}\n",
"\t Slowdist: (Array([-0.74557585, -0.7132534 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 9900\n",
"\t Loss: 1.5415018796920776\n",
"\t Params: {'log_alpha': Array([-7.2537313, -4.6229324], dtype=float32), 'log_penalty_temperature': Array([2.4425437], dtype=float32), 'log_temperature': Array([0.7956085], dtype=float32), 'mean': Array([-0.6918645 , -0.77884036], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.56718874, dtype=float32), 'loss_penalty_temperature': Array(-0.05842334, dtype=float32), 'loss_temperature': Array(1.0326318, dtype=float32), 'non_parametric_kl': Array(0.10385121, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.0010557, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([ 4.584762e-05, -6.738090e-05], dtype=float32), 'log_penalty_temperature': Array([-0.00258085], dtype=float32), 'log_temperature': Array([0.00067314], dtype=float32), 'mean': Array([ 0.00150003, -0.00339614], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([7.069970e-06, 9.729022e-05], dtype=float32), 'log_penalty_temperature': Array([-5.126626e-05], dtype=float32), 'log_temperature': Array([-0.00265305], dtype=float32), 'mean': Array([0.08787138, 0.15549716], dtype=float32)}\n",
"\t Slowdist: (Array([-0.6933645, -0.7754442], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 10000\n",
"\t Loss: 1.5770025253295898\n",
"\t Params: {'log_alpha': Array([-7.306034, -4.682281], dtype=float32), 'log_penalty_temperature': Array([2.5138862], dtype=float32), 'log_temperature': Array([0.79651], dtype=float32), 'mean': Array([-0.7482572 , -0.74499685], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.66734385, dtype=float32), 'loss_penalty_temperature': Array(-0.05808659, dtype=float32), 'loss_temperature': Array(0.96764624, dtype=float32), 'non_parametric_kl': Array(0.09782319, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00061442, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.0004379 , -0.00063469], dtype=float32), 'log_penalty_temperature': Array([0.00521892], dtype=float32), 'log_temperature': Array([-0.00043414], dtype=float32), 'mean': Array([-0.00049985, 0.003014 ], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([6.713179e-06, 9.178720e-05], dtype=float32), 'log_penalty_temperature': Array([0.0003565], dtype=float32), 'log_temperature': Array([0.00150051], dtype=float32), 'mean': Array([ 0.01789659, -0.36625358], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7477574, -0.7480109], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 10100\n",
"\t Loss: 1.7157970666885376\n",
"\t Params: {'log_alpha': Array([-7.369187, -4.743017], dtype=float32), 'log_penalty_temperature': Array([2.5081284], dtype=float32), 'log_temperature': Array([0.7924902], dtype=float32), 'mean': Array([-0.7613999, -0.7031672], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.79663074, dtype=float32), 'loss_penalty_temperature': Array(-0.05210591, dtype=float32), 'loss_temperature': Array(0.97117907, dtype=float32), 'non_parametric_kl': Array(0.09005186, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00101173, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00067674, -0.0006044 ], dtype=float32), 'log_penalty_temperature': Array([-0.00119803], dtype=float32), 'log_temperature': Array([-0.00182492], dtype=float32), 'mean': Array([-0.00061225, 0.00356821], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([6.3040975e-06, 8.6422486e-05], dtype=float32), 'log_penalty_temperature': Array([-1.0841908e-05], dtype=float32), 'log_temperature': Array([0.00685192], dtype=float32), 'mean': Array([ 0.23789835, -1.3190596 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.76078767, -0.70673543], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 10200\n",
"\t Loss: 1.4849193096160889\n",
"\t Params: {'log_alpha': Array([-7.431063, -4.797586], dtype=float32), 'log_penalty_temperature': Array([2.2638736], dtype=float32), 'log_temperature': Array([0.8208819], dtype=float32), 'mean': Array([-0.7335963, -0.6946234], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.49297315, dtype=float32), 'loss_penalty_temperature': Array(-0.0382915, dtype=float32), 'loss_temperature': Array(1.0301496, dtype=float32), 'non_parametric_kl': Array(0.09462324, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00094949, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00065 , -0.00045567], dtype=float32), 'log_penalty_temperature': Array([0.000584], dtype=float32), 'log_temperature': Array([-0.00317616], dtype=float32), 'mean': Array([ 0.00021808, -0.00106485], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([5.925911e-06, 8.185837e-05], dtype=float32), 'log_penalty_temperature': Array([4.574404e-05], dtype=float32), 'log_temperature': Array([0.00373735], dtype=float32), 'mean': Array([ 0.09377901, -0.61949587], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7338144, -0.6935586], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 10300\n",
"\t Loss: 1.3528486490249634\n",
"\t Params: {'log_alpha': Array([-7.490832, -4.854326], dtype=float32), 'log_penalty_temperature': Array([2.4796612], dtype=float32), 'log_temperature': Array([0.78640085], dtype=float32), 'mean': Array([-0.7115039 , -0.73749083], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.351404, dtype=float32), 'loss_penalty_temperature': Array(-0.03735584, dtype=float32), 'loss_temperature': Array(1.038717, dtype=float32), 'non_parametric_kl': Array(0.08400993, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00069372, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00058072, -0.00066128], dtype=float32), 'log_penalty_temperature': Array([-0.00259684], dtype=float32), 'log_temperature': Array([0.00199654], dtype=float32), 'mean': Array([-0.00015173, -0.00148585], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([5.5819096e-06, 7.7393772e-05], dtype=float32), 'log_penalty_temperature': Array([0.00028269], dtype=float32), 'log_temperature': Array([0.01097923], dtype=float32), 'mean': Array([-0.04627372, -0.5642295 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.71135217, -0.736005 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 10400\n",
"\t Loss: 1.7393016815185547\n",
"\t Params: {'log_alpha': Array([-7.549619 , -4.8887577], dtype=float32), 'log_penalty_temperature': Array([2.255406], dtype=float32), 'log_temperature': Array([0.8318192], dtype=float32), 'mean': Array([-0.7179575 , -0.69254744], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.8620701, dtype=float32), 'loss_penalty_temperature': Array(-0.04583537, dtype=float32), 'loss_temperature': Array(0.9229867, dtype=float32), 'non_parametric_kl': Array(0.09209978, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.0009287, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00056696, -0.00024879], dtype=float32), 'log_penalty_temperature': Array([-0.00392635], dtype=float32), 'log_temperature': Array([-0.00030222], dtype=float32), 'mean': Array([-0.00167185, 0.00267832], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([5.2633181e-06, 7.4763295e-05], dtype=float32), 'log_penalty_temperature': Array([6.452342e-05], dtype=float32), 'log_temperature': Array([0.00550494], dtype=float32), 'mean': Array([ 0.38253748, -0.31613526], dtype=float32)}\n",
"\t Slowdist: (Array([-0.71628565, -0.6952258 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 10500\n",
"\t Loss: 1.3801509141921997\n",
"\t Params: {'log_alpha': Array([-7.5987926, -4.9345603], dtype=float32), 'log_penalty_temperature': Array([2.3930993], dtype=float32), 'log_temperature': Array([0.76199573], dtype=float32), 'mean': Array([-0.7192682 , -0.71724427], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.3493977, dtype=float32), 'loss_penalty_temperature': Array(-0.04379991, dtype=float32), 'loss_temperature': Array(1.0744764, dtype=float32), 'non_parametric_kl': Array(0.11628091, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00083258, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00055628, -0.00056655], dtype=float32), 'log_penalty_temperature': Array([-0.0055849], dtype=float32), 'log_temperature': Array([0.00127752], dtype=float32), 'mean': Array([0.00212015, 0.00129195], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([5.0108379e-06, 7.1462615e-05], dtype=float32), 'log_penalty_temperature': Array([0.0001535], dtype=float32), 'log_temperature': Array([-0.01109564], dtype=float32), 'mean': Array([ 0.4212528, -0.1260084], dtype=float32)}\n",
"\t Slowdist: (Array([-0.72138834, -0.7185362 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 10600\n",
"\t Loss: 1.560949444770813\n",
"\t Params: {'log_alpha': Array([-7.657649 , -4.9954553], dtype=float32), 'log_penalty_temperature': Array([2.396031], dtype=float32), 'log_temperature': Array([0.8156956], dtype=float32), 'mean': Array([-0.74921477, -0.7529623 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.5505645, dtype=float32), 'loss_penalty_temperature': Array(-0.05090782, dtype=float32), 'loss_temperature': Array(1.0612205, dtype=float32), 'non_parametric_kl': Array(0.09356704, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00096626, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00055292, -0.0006371 ], dtype=float32), 'log_penalty_temperature': Array([0.0037451], dtype=float32), 'log_temperature': Array([0.00182007], dtype=float32), 'mean': Array([-4.9280301e-05, 2.1102752e-03], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([4.724548e-06, 6.727388e-05], dtype=float32), 'log_penalty_temperature': Array([3.0840307e-05], dtype=float32), 'log_temperature': Array([0.00445748], dtype=float32), 'mean': Array([-0.40488228, -0.5921927 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7491655, -0.7550726], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 10700\n",
"\t Loss: 1.5048987865447998\n",
"\t Params: {'log_alpha': Array([-7.7113504, -5.0450945], dtype=float32), 'log_penalty_temperature': Array([2.5073202], dtype=float32), 'log_temperature': Array([0.81161827], dtype=float32), 'mean': Array([-0.76921177, -0.6990257 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.5496546, dtype=float32), 'loss_penalty_temperature': Array(-0.06646447, dtype=float32), 'loss_temperature': Array(1.02164, dtype=float32), 'non_parametric_kl': Array(0.1086131, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00115493, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00045965, -0.00041754], dtype=float32), 'log_penalty_temperature': Array([0.00015245], dtype=float32), 'log_temperature': Array([-0.00054931], dtype=float32), 'mean': Array([-0.0041294 , -0.00115229], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([4.4772178e-06, 6.4022883e-05], dtype=float32), 'log_penalty_temperature': Array([-0.00014329], dtype=float32), 'log_temperature': Array([-0.00596508], dtype=float32), 'mean': Array([ 0.3311998 , -0.13006158], dtype=float32)}\n",
"\t Slowdist: (Array([-0.76508236, -0.6978734 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 10800\n",
"\t Loss: 1.4716602563858032\n",
"\t Params: {'log_alpha': Array([-7.769665, -5.103329], dtype=float32), 'log_penalty_temperature': Array([2.4047067], dtype=float32), 'log_temperature': Array([0.81402534], dtype=float32), 'mean': Array([-0.6878359, -0.7346125], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.5118525, dtype=float32), 'loss_penalty_temperature': Array(-0.04776013, dtype=float32), 'loss_temperature': Array(1.007503, dtype=float32), 'non_parametric_kl': Array(0.09733938, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00073205, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00058256, -0.00056586], dtype=float32), 'log_penalty_temperature': Array([-0.00229245], dtype=float32), 'log_temperature': Array([-0.001663], dtype=float32), 'mean': Array([ 0.00640505, -0.00275819], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([4.2242250e-06, 6.0431812e-05], dtype=float32), 'log_penalty_temperature': Array([0.00024579], dtype=float32), 'log_temperature': Array([0.00184474], dtype=float32), 'mean': Array([-0.13967359, 1.0706575 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.6942409, -0.7318543], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 10900\n",
"\t Loss: 1.683466911315918\n",
"\t Params: {'log_alpha': Array([-7.816053, -5.161979], dtype=float32), 'log_penalty_temperature': Array([2.5315092], dtype=float32), 'log_temperature': Array([0.83415514], dtype=float32), 'mean': Array([-0.79112667, -0.7410353 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.8108499, dtype=float32), 'loss_penalty_temperature': Array(-0.07805103, dtype=float32), 'loss_temperature': Array(0.9506069, dtype=float32), 'non_parametric_kl': Array(0.1100352, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.0013178, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00052684, -0.00064512], dtype=float32), 'log_penalty_temperature': Array([0.00641698], dtype=float32), 'log_temperature': Array([0.0013347], dtype=float32), 'mean': Array([-0.00061232, -0.00372636], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([4.0325986e-06, 5.7013527e-05], dtype=float32), 'log_penalty_temperature': Array([-0.00029423], dtype=float32), 'log_temperature': Array([-0.00699415], dtype=float32), 'mean': Array([0.35537305, 0.43576503], dtype=float32)}\n",
"\t Slowdist: (Array([-0.79051435, -0.7373089 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 11000\n",
"\t Loss: 1.549668550491333\n",
"\t Params: {'log_alpha': Array([-7.8696923, -5.2213755], dtype=float32), 'log_penalty_temperature': Array([2.6399543], dtype=float32), 'log_temperature': Array([0.8591374], dtype=float32), 'mean': Array([-0.6902345 , -0.74372137], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.62358207, dtype=float32), 'loss_penalty_temperature': Array(-0.0544759, dtype=float32), 'loss_temperature': Array(0.98050463, dtype=float32), 'non_parametric_kl': Array(0.10318165, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00083214, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00039959, -0.00022339], dtype=float32), 'log_penalty_temperature': Array([-0.00267324], dtype=float32), 'log_temperature': Array([0.00010929], dtype=float32), 'mean': Array([0.00111121, 0.00036142], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([3.8215844e-06, 5.3720880e-05], dtype=float32), 'log_penalty_temperature': Array([0.00015671], dtype=float32), 'log_temperature': Array([-0.00223494], dtype=float32), 'mean': Array([-0.30036518, 0.5787225 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.6913457, -0.7440828], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 11100\n",
"\t Loss: 1.4529552459716797\n",
"\t Params: {'log_alpha': Array([-7.919013, -5.27219 ], dtype=float32), 'log_penalty_temperature': Array([2.4066203], dtype=float32), 'log_temperature': Array([0.818809], dtype=float32), 'mean': Array([-0.7377936, -0.7118941], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.5474588, dtype=float32), 'loss_penalty_temperature': Array(-0.052452, dtype=float32), 'loss_temperature': Array(0.9578936, dtype=float32), 'non_parametric_kl': Array(0.08880925, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00089152, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.0004337 , -0.00057554], dtype=float32), 'log_penalty_temperature': Array([-0.00010892], dtype=float32), 'log_temperature': Array([0.00235944], dtype=float32), 'mean': Array([ 0.00046856, -0.00016646], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([3.6378665e-06, 5.1090759e-05], dtype=float32), 'log_penalty_temperature': Array([9.947565e-05], dtype=float32), 'log_temperature': Array([0.00776063], dtype=float32), 'mean': Array([ 0.6748087, -0.9661611], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7382622, -0.7117276], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 11200\n",
"\t Loss: 1.5104057788848877\n",
"\t Params: {'log_alpha': Array([-7.97489 , -5.3144364], dtype=float32), 'log_penalty_temperature': Array([2.4928787], dtype=float32), 'log_temperature': Array([0.78362966], dtype=float32), 'mean': Array([-0.76717985, -0.721031 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.64846045, dtype=float32), 'loss_penalty_temperature': Array(-0.05228491, dtype=float32), 'loss_temperature': Array(0.91417766, dtype=float32), 'non_parametric_kl': Array(0.095369, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00062228, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00054886, -0.0005598 ], dtype=float32), 'log_penalty_temperature': Array([-0.00298592], dtype=float32), 'log_temperature': Array([0.00060433], dtype=float32), 'mean': Array([-0.00153177, -0.00085559], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([3.4406305e-06, 4.8986894e-05], dtype=float32), 'log_penalty_temperature': Array([0.00034901], dtype=float32), 'log_temperature': Array([0.0031785], dtype=float32), 'mean': Array([-0.36911818, 0.10186333], dtype=float32)}\n",
"\t Slowdist: (Array([-0.76564807, -0.72017545], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 11300\n",
"\t Loss: 1.339586615562439\n",
"\t Params: {'log_alpha': Array([-8.004655, -5.353729], dtype=float32), 'log_penalty_temperature': Array([2.4383707], dtype=float32), 'log_temperature': Array([0.815633], dtype=float32), 'mean': Array([-0.72353786, -0.74400485], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.35903665, dtype=float32), 'loss_penalty_temperature': Array(-0.0507497, dtype=float32), 'loss_temperature': Array(1.0312492, dtype=float32), 'non_parametric_kl': Array(0.08305493, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00104102, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.0004072, -0.0005657], dtype=float32), 'log_penalty_temperature': Array([0.00242156], dtype=float32), 'log_temperature': Array([0.000774], dtype=float32), 'mean': Array([0.00206211, 0.00146392], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([3.3392926e-06, 4.7108573e-05], dtype=float32), 'log_penalty_temperature': Array([-3.7716665e-05], dtype=float32), 'log_temperature': Array([0.01174544], dtype=float32), 'mean': Array([0.1380593 , 0.08598936], dtype=float32)}\n",
"\t Slowdist: (Array([-0.72559994, -0.7454688 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n",
"Iteration 11400\n",
"\t Loss: 1.4412344694137573\n",
"\t Params: {'log_alpha': Array([-8.064666 , -5.4154315], dtype=float32), 'log_penalty_temperature': Array([2.4510052], dtype=float32), 'log_temperature': Array([0.8200129], dtype=float32), 'mean': Array([-0.71201926, -0.72406536], dtype=float32)}\n",
"\t Metrics: {'loss_dist': Array(0.4710958, dtype=float32), 'loss_penalty_temperature': Array(-0.04272904, dtype=float32), 'loss_temperature': Array(1.0128201, dtype=float32), 'non_parametric_kl': Array(0.10607076, dtype=float32), 'parametric_kl': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00061154, dtype=float32)}\n",
"\t Updates: {'log_alpha': Array([-0.00042624, -0.00060642], dtype=float32), 'log_penalty_temperature': Array([-0.00530804], dtype=float32), 'log_temperature': Array([-0.0014828], dtype=float32), 'mean': Array([ 0.0025397 , -0.00143018], dtype=float32)}\n",
"\t Grad: {'log_alpha': Array([3.1449126e-06, 4.4304001e-05], dtype=float32), 'log_penalty_temperature': Array([0.00035781], dtype=float32), 'log_temperature': Array([-0.00421638], dtype=float32), 'mean': Array([ 0.0235918, -0.4316637], dtype=float32)}\n",
"\t Slowdist: (Array([-0.71455896, -0.7226352 ], dtype=float32), Array([0.3, 0.3], dtype=float32))\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/home/ale/awake/embodied/agents/mpo/test_kl.ipynb Cell 6\u001b[0m line \u001b[0;36m1\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W5sdnNjb2RlLXJlbW90ZQ%3D%3D?line=109'>110</a>\u001b[0m (loss, metrics), grads \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39mvalue_and_grad(loss_fn, has_aux\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)(params, key, slowdist)\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W5sdnNjb2RlLXJlbW90ZQ%3D%3D?line=110'>111</a>\u001b[0m updates, opt_state \u001b[39m=\u001b[39m optimizer\u001b[39m.\u001b[39mupdate(grads, opt_state)\n\u001b[0;32m--> <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W5sdnNjb2RlLXJlbW90ZQ%3D%3D?line=111'>112</a>\u001b[0m params \u001b[39m=\u001b[39m optax\u001b[39m.\u001b[39;49mapply_updates(params, updates)\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W5sdnNjb2RlLXJlbW90ZQ%3D%3D?line=112'>113</a>\u001b[0m means\u001b[39m.\u001b[39mappend(params[\u001b[39m'\u001b[39m\u001b[39mmean\u001b[39m\u001b[39m'\u001b[39m])\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bdgx5/home/ale/awake/embodied/agents/mpo/test_kl.ipynb#W5sdnNjb2RlLXJlbW90ZQ%3D%3D?line=113'>114</a>\u001b[0m \u001b[39mif\u001b[39;00m i \u001b[39m%\u001b[39m \u001b[39m100\u001b[39m \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m:\n",
"File \u001b[0;32m~/miniconda3/envs/awake/lib/python3.10/site-packages/optax/_src/update.py:42\u001b[0m, in \u001b[0;36mapply_updates\u001b[0;34m(params, updates)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mapply_updates\u001b[39m(params: base\u001b[39m.\u001b[39mParams, updates: base\u001b[39m.\u001b[39mUpdates) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m base\u001b[39m.\u001b[39mParams:\n\u001b[1;32m 25\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Applies an update to the corresponding parameters.\u001b[39;00m\n\u001b[1;32m 26\u001b[0m \n\u001b[1;32m 27\u001b[0m \u001b[39m This is a utility functions that applies an update to a set of parameters, and\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[39m Updated parameters, with same structure, shape and type as `params`.\u001b[39;00m\n\u001b[1;32m 41\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 42\u001b[0m \u001b[39mreturn\u001b[39;00m jax\u001b[39m.\u001b[39;49mtree_util\u001b[39m.\u001b[39;49mtree_map(\n\u001b[1;32m 43\u001b[0m \u001b[39mlambda\u001b[39;49;00m p, u: jnp\u001b[39m.\u001b[39;49masarray(p \u001b[39m+\u001b[39;49m u)\u001b[39m.\u001b[39;49mastype(jnp\u001b[39m.\u001b[39;49masarray(p)\u001b[39m.\u001b[39;49mdtype),\n\u001b[1;32m 44\u001b[0m params, updates)\n",
"File \u001b[0;32m~/miniconda3/envs/awake/lib/python3.10/site-packages/jax/_src/tree_util.py:244\u001b[0m, in \u001b[0;36mtree_map\u001b[0;34m(f, tree, is_leaf, *rest)\u001b[0m\n\u001b[1;32m 242\u001b[0m leaves, treedef \u001b[39m=\u001b[39m tree_flatten(tree, is_leaf)\n\u001b[1;32m 243\u001b[0m all_leaves \u001b[39m=\u001b[39m [leaves] \u001b[39m+\u001b[39m [treedef\u001b[39m.\u001b[39mflatten_up_to(r) \u001b[39mfor\u001b[39;00m r \u001b[39min\u001b[39;00m rest]\n\u001b[0;32m--> 244\u001b[0m \u001b[39mreturn\u001b[39;00m treedef\u001b[39m.\u001b[39;49munflatten(f(\u001b[39m*\u001b[39;49mxs) \u001b[39mfor\u001b[39;49;00m xs \u001b[39min\u001b[39;49;00m \u001b[39mzip\u001b[39;49m(\u001b[39m*\u001b[39;49mall_leaves))\n",
"File \u001b[0;32m~/miniconda3/envs/awake/lib/python3.10/site-packages/jax/_src/tree_util.py:244\u001b[0m, in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 242\u001b[0m leaves, treedef \u001b[39m=\u001b[39m tree_flatten(tree, is_leaf)\n\u001b[1;32m 243\u001b[0m all_leaves \u001b[39m=\u001b[39m [leaves] \u001b[39m+\u001b[39m [treedef\u001b[39m.\u001b[39mflatten_up_to(r) \u001b[39mfor\u001b[39;00m r \u001b[39min\u001b[39;00m rest]\n\u001b[0;32m--> 244\u001b[0m \u001b[39mreturn\u001b[39;00m treedef\u001b[39m.\u001b[39munflatten(f(\u001b[39m*\u001b[39;49mxs) \u001b[39mfor\u001b[39;00m xs \u001b[39min\u001b[39;00m \u001b[39mzip\u001b[39m(\u001b[39m*\u001b[39mall_leaves))\n",
"File \u001b[0;32m~/miniconda3/envs/awake/lib/python3.10/site-packages/optax/_src/update.py:43\u001b[0m, in \u001b[0;36mapply_updates.<locals>.<lambda>\u001b[0;34m(p, u)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mapply_updates\u001b[39m(params: base\u001b[39m.\u001b[39mParams, updates: base\u001b[39m.\u001b[39mUpdates) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m base\u001b[39m.\u001b[39mParams:\n\u001b[1;32m 25\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Applies an update to the corresponding parameters.\u001b[39;00m\n\u001b[1;32m 26\u001b[0m \n\u001b[1;32m 27\u001b[0m \u001b[39m This is a utility functions that applies an update to a set of parameters, and\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[39m Updated parameters, with same structure, shape and type as `params`.\u001b[39;00m\n\u001b[1;32m 41\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[1;32m 42\u001b[0m \u001b[39mreturn\u001b[39;00m jax\u001b[39m.\u001b[39mtree_util\u001b[39m.\u001b[39mtree_map(\n\u001b[0;32m---> 43\u001b[0m \u001b[39mlambda\u001b[39;00m p, u: jnp\u001b[39m.\u001b[39masarray(p \u001b[39m+\u001b[39;49m u)\u001b[39m.\u001b[39mastype(jnp\u001b[39m.\u001b[39masarray(p)\u001b[39m.\u001b[39mdtype),\n\u001b[1;32m 44\u001b[0m params, updates)\n",
"File \u001b[0;32m~/miniconda3/envs/awake/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:256\u001b[0m, in \u001b[0;36m_defer_to_unrecognized_arg.<locals>.deferring_binary_op\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 254\u001b[0m args \u001b[39m=\u001b[39m (other, \u001b[39mself\u001b[39m) \u001b[39mif\u001b[39;00m swap \u001b[39melse\u001b[39;00m (\u001b[39mself\u001b[39m, other)\n\u001b[1;32m 255\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(other, _accepted_binop_types):\n\u001b[0;32m--> 256\u001b[0m \u001b[39mreturn\u001b[39;00m binary_op(\u001b[39m*\u001b[39;49margs)\n\u001b[1;32m 257\u001b[0m \u001b[39m# Note: don't use isinstance here, because we don't want to raise for\u001b[39;00m\n\u001b[1;32m 258\u001b[0m \u001b[39m# subclasses, e.g. NamedTuple objects that may override operators.\u001b[39;00m\n\u001b[1;32m 259\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mtype\u001b[39m(other) \u001b[39min\u001b[39;00m _rejected_binop_types:\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"# Toy problem definition; loss_fn reads these as notebook globals.\n",
"target = -0.75\n",
"std = 0.3\n",
"n_act_samples = 100\n",
"\n",
"# Optimiser and PRNG seed.\n",
"optimizer = optax.adam(learning_rate=1e-2)\n",
"key = jax.random.PRNGKey(41)\n",
"\n",
"# Training switches: slow-distribution refresh period and which loss terms are active.\n",
"slowdist_update_freq = 25\n",
"use_objective_constraint = True\n",
"use_slowdist_loss = True\n",
"\n",
"@jax.jit\n",
"def loss_fn(params, key, slowdist):\n",
"    \"\"\"Compute the MPO-style loss of a fixed-std Gaussian policy on a toy objective.\n",
"\n",
"    Reads the notebook globals `std`, `n_act_samples`, `target`,\n",
"    `use_objective_constraint` and `use_slowdist_loss`. The function is jitted,\n",
"    so those values are captured when it is traced.\n",
"\n",
"    Args:\n",
"        params: dict with keys 'mean', 'log_temperature',\n",
"            'log_penalty_temperature' and 'log_alpha'.\n",
"        key: PRNG key used to sample actions from the current policy.\n",
"        slowdist: `tfd.Independent` distribution; its inner `.distribution` is\n",
"            used for the per-dimension KL against the current policy.\n",
"\n",
"    Returns:\n",
"        `(loss, metrics)`, where `metrics` is a dict holding the individual\n",
"        loss terms and KL estimates.\n",
"    \"\"\"\n",
"    metrics = {}\n",
"    mean = params['mean']\n",
"    # Softplus keeps the dual variables positive; the small constant keeps them off zero.\n",
"    temperature = jax.nn.softplus(params['log_temperature']) + 1e-8\n",
"    penalty_temperature = jax.nn.softplus(params['log_penalty_temperature']) + 1e-8\n",
"    alpha = jax.nn.softplus(params['log_alpha']) + 1e-8\n",
"\n",
"    dist = tfd.Independent(\n",
"        tfd.Normal(mean, jax.lax.stop_gradient(std * jnp.ones_like(mean))),\n",
"        reinterpreted_batch_ndims=1)\n",
"    a_improvement = dist.sample(n_act_samples, seed=key)\n",
"    # Toy objective: a Gaussian bump centred on `target`, summed over the last axis.\n",
"    q_improvement = jax.lax.stop_gradient(\n",
"        jnp.sum(jnp.exp(-30 * (a_improvement - target)**2), -1))\n",
"\n",
"    def compute_weights_and_temperature_loss(q_values, epsilon, temperature):\n",
"        \"\"\"Return softmax sample weights, the temperature dual loss and the sample KL.\"\"\"\n",
"        tempered_q_values = jax.lax.stop_gradient(q_values) / temperature\n",
"        num_action_samples = tempered_q_values.shape[0]\n",
"        # log(mean(exp(x))) written as logsumexp(x) - log(n): the naive form\n",
"        # overflows exp() once q / temperature grows large.\n",
"        q_logsumexp = jnp.mean(\n",
"            jax.scipy.special.logsumexp(tempered_q_values, axis=0)\n",
"            - jnp.log(num_action_samples))\n",
"        loss_temperature = jnp.mean(temperature * epsilon + temperature * q_logsumexp)\n",
"        normalized_weights = jax.lax.stop_gradient(jax.nn.softmax(tempered_q_values, axis=0))\n",
"        # sum_i w_i * log(n * w_i); the 1e-8 guards log(0) for vanishing weights.\n",
"        integrand = jnp.log(num_action_samples * normalized_weights + 1e-8)\n",
"        non_parametric_kl = jnp.sum(normalized_weights * integrand, axis=0)\n",
"        return normalized_weights, loss_temperature, non_parametric_kl\n",
"\n",
"    normalized_weights, loss_temperature, non_parametric_kl = (\n",
"        compute_weights_and_temperature_loss(q_improvement, 0.1, temperature))\n",
"    metrics['loss_temperature'] = loss_temperature\n",
"    metrics['non_parametric_kl'] = non_parametric_kl\n",
"\n",
"    # Cost for samples outside the [-1, 1] box: negative distance to the box.\n",
"    diff_out_of_bound = a_improvement - jnp.clip(a_improvement, -1.0, 1.0)\n",
"    cost_out_of_bound = -jnp.linalg.norm(diff_out_of_bound, axis=-1)\n",
"    penalty_normalized_weights, loss_penalty_temperature, penalty_non_parametric_kl = (\n",
"        compute_weights_and_temperature_loss(cost_out_of_bound, 0.001, penalty_temperature))\n",
"    metrics['loss_penalty_temperature'] = loss_penalty_temperature\n",
"    metrics['penalty_non_parametric_kl'] = penalty_non_parametric_kl\n",
"\n",
"    if use_objective_constraint:\n",
"        # Combine the objective and the out-of-bound penalty.\n",
"        loss_temperature += loss_penalty_temperature\n",
"        normalized_weights += penalty_normalized_weights\n",
"    else:\n",
"        # Penalty only: the objective terms are dropped.\n",
"        loss_temperature = loss_penalty_temperature\n",
"        normalized_weights = penalty_normalized_weights\n",
"\n",
"    def compute_parametric_kl_penalty_and_dual_loss(kl, alpha, epsilon):\n",
"        \"\"\"Return the alpha-weighted KL penalty and the dual loss for alpha.\"\"\"\n",
"        loss_kl = jnp.sum(jax.lax.stop_gradient(alpha) * kl, -1)\n",
"        loss_alpha = jnp.sum(alpha * (epsilon - jax.lax.stop_gradient(kl)), -1)\n",
"        return loss_kl, loss_alpha\n",
"\n",
"    kl = slowdist.distribution.kl_divergence(dist.distribution)\n",
"    loss_kl, loss_alpha = compute_parametric_kl_penalty_and_dual_loss(kl, alpha, 0.01)\n",
"    metrics['parametric_kl'] = jnp.mean(kl)\n",
"\n",
"    # Weighted maximum likelihood on the sampled actions, which are held fixed.\n",
"    logpi = dist.log_prob(jax.lax.stop_gradient(a_improvement))\n",
"    loss_dist = jnp.mean(-jnp.sum(normalized_weights * logpi, axis=0))\n",
"    metrics['loss_dist'] = loss_dist\n",
"    loss = loss_dist + loss_temperature\n",
"    # NOTE(review): both KL terms are assumed to sit under this flag; the\n",
"    # source indentation was lost, so confirm against the original notebook.\n",
"    if use_slowdist_loss:\n",
"        loss += loss_kl\n",
"        loss += loss_alpha\n",
"\n",
"    return loss, metrics\n",
"\n",
"\n",
"def make_slowdist(params):\n",
"    \"\"\"Build a gradient-free Normal(params['mean'], std) snapshot of the policy.\"\"\"\n",
"    return tfd.Independent(tfd.Normal(\n",
"        jax.lax.stop_gradient(params['mean']),\n",
"        jax.lax.stop_gradient(std * jnp.ones_like(params['mean']))), 1)\n",
"\n",
"\n",
"params = {\n",
"    'mean': jnp.array([2.0, 2.0]),\n",
"    'log_temperature': jnp.array([10.]),\n",
"    'log_penalty_temperature': jnp.array([10.]),\n",
"    'log_alpha': jnp.array([10., 10.]),}\n",
"opt_state = optimizer.init(params)\n",
"slowdist = make_slowdist(params)\n",
"\n",
"# Built once: the transformed function does not depend on the iteration.\n",
"loss_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n",
"\n",
"# Policy mean after every update, kept for plotting.\n",
"means = []\n",
"for i in range(20000):\n",
"    # Refresh the slow distribution from the current parameters every\n",
"    # `slowdist_update_freq` steps (this includes step 0).\n",
"    if i % slowdist_update_freq == 0:\n",
"        slowdist = make_slowdist(params)\n",
"    # NOTE(review): the key split is assumed to run on every iteration; the\n",
"    # source indentation was lost, so confirm against the original notebook.\n",
"    _, key = jax.random.split(key)\n",
"    (loss, metrics), grads = loss_and_grad_fn(params, key, slowdist)\n",
"    updates, opt_state = optimizer.update(grads, opt_state)\n",
"    params = optax.apply_updates(params, updates)\n",
"    means.append(params['mean'])\n",
"    if i % 100 == 0:\n",
"        print(f'Iteration {i}')\n",
"        print(f'\\t Loss: {loss}')\n",
"        print(f'\\t Params: {params}')\n",
"        print(f'\\t Metrics: {metrics}')\n",
"        print(f'\\t Updates: {updates}')\n",
"        print(f'\\t Grad: {grads}')\n",
"        print(f'\\t Slowdist: {slowdist.mean(), slowdist.stddev()}')"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"a = np.array(means)\n",
"fig, ax = plt.subplots()\n",
"# All recorded policy means, the first recorded one in green, the target in red.\n",
"ax.scatter(a[:, 0], a[:, 1])\n",
"ax.scatter(a[0, 0], a[0, 1], color='g')\n",
"ax.scatter(target, target, color='r')\n",
"# Outline of the [-1, 1] x [-1, 1] box, drawn one edge at a time.\n",
"box_edges = (\n",
"    ([-1, 1], [1, 1]),\n",
"    ([-1, 1], [-1, -1]),\n",
"    ([-1, -1], [-1, 1]),\n",
"    ([1, 1], [-1, 1]),\n",
")\n",
"for edge_x, edge_y in box_edges:\n",
"    ax.plot(edge_x, edge_y, color='r')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration 0\n",
"\t Loss: 10.634836196899414\n",
"\t Params: {'log_alpha_mean': Array([9.99, 9.99], dtype=float32), 'log_alpha_scale': Array([9.99, 9.99], dtype=float32), 'log_penalty_temperature': Array([10.01], dtype=float32), 'log_temperature': Array([9.99], dtype=float32), 'mean': Array([1.99, 2.01], dtype=float32), 'std': Array([0.99000007, 0.99000007], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(5.5347853, dtype=float32), 'loss_dist_scale': Array(5.5347853, dtype=float32), 'loss_penalty_temperature': Array(-1.6811458, dtype=float32), 'loss_temperature': Array(1.0264105, dtype=float32), 'non_parametric_kl': Array(0.00011204, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00380463, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00999992, -0.00999992], dtype=float32), 'log_alpha_scale': Array([-0.00999983, -0.00999983], dtype=float32), 'log_penalty_temperature': Array([0.0099999], dtype=float32), 'log_temperature': Array([-0.00999993], dtype=float32), 'mean': Array([-0.00999993, 0.00999993], dtype=float32), 'std': Array([-0.00999993, -0.00999993], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00999954, 0.00999954], dtype=float32), 'log_alpha_scale': Array([0.00099995, 0.00099995], dtype=float32), 'log_penalty_temperature': Array([-0.00280459], dtype=float32), 'log_temperature': Array([0.09988341], dtype=float32), 'mean': Array([ 0.19940747, -0.12546885], dtype=float32), 'std': Array([0.17432547, 0.10761261], dtype=float32)}\n",
"\t Slowdist: (Array([2., 2.], dtype=float32), Array([1., 1.], dtype=float32))\n",
"Iteration 100\n",
"\t Loss: 10.501836776733398\n",
"\t Params: {'log_alpha_mean': Array([8.9898405, 8.990711 ], dtype=float32), 'log_alpha_scale': Array([8.972935, 9.004888], dtype=float32), 'log_penalty_temperature': Array([10.947263], dtype=float32), 'log_temperature': Array([8.990028], dtype=float32), 'mean': Array([2.0018919, 1.9789495], dtype=float32), 'std': Array([0.98237216, 1.003999 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(5.4963284, dtype=float32), 'loss_dist_scale': Array(5.4963284, dtype=float32), 'loss_penalty_temperature': Array(-1.5888923, dtype=float32), 'loss_temperature': Array(0.9000655, dtype=float32), 'non_parametric_kl': Array(7.0234364e-08, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00356952, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.01003395, -0.00994753], dtype=float32), 'log_alpha_scale': Array([-0.01016829, -0.01042152], dtype=float32), 'log_penalty_temperature': Array([0.00926764], dtype=float32), 'log_temperature': Array([-0.00999924], dtype=float32), 'mean': Array([-0.00264046, 0.00381722], dtype=float32), 'std': Array([-5.8914506e-05, -5.2382553e-04], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00999877, 0.00999877], dtype=float32), 'log_alpha_scale': Array([0.00099987, 0.00099988], dtype=float32), 'log_penalty_temperature': Array([-0.00256951], dtype=float32), 'log_temperature': Array([0.09998763], dtype=float32), 'mean': Array([ 0.27481824, -0.02431485], dtype=float32), 'std': Array([0.0091095 , 0.29571533], dtype=float32)}\n",
"\t Slowdist: (Array([2.0045323, 1.9751323], dtype=float32), Array([0.98243105, 1.0045228 ], dtype=float32))\n",
"Iteration 200\n",
"\t Loss: 11.171329498291016\n",
"\t Params: {'log_alpha_mean': Array([7.9893394, 7.9916687], dtype=float32), 'log_alpha_scale': Array([7.9633045, 7.9910583], dtype=float32), 'log_penalty_temperature': Array([11.798137], dtype=float32), 'log_temperature': Array([7.990184], dtype=float32), 'mean': Array([2.0077944, 1.9766798], dtype=float32), 'std': Array([1.0159093, 1.0044231], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(5.8615475, dtype=float32), 'loss_dist_scale': Array(5.8615475, dtype=float32), 'loss_penalty_temperature': Array(-1.5463716, dtype=float32), 'loss_temperature': Array(0.8186146, dtype=float32), 'non_parametric_kl': Array(0.00012675, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00307883, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00997329, -0.00996955], dtype=float32), 'log_alpha_scale': Array([-0.01004053, -0.00998015], dtype=float32), 'log_penalty_temperature': Array([0.00844316], dtype=float32), 'log_temperature': Array([-0.00999664], dtype=float32), 'mean': Array([-0.00138968, 0.00061454], dtype=float32), 'std': Array([-0.00033521, 0.00300407], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00999664, 0.00999665], dtype=float32), 'log_alpha_scale': Array([0.00099966, 0.00099967], dtype=float32), 'log_penalty_temperature': Array([-0.00207885], dtype=float32), 'log_temperature': Array([0.09983971], dtype=float32), 'mean': Array([0.31438398, 0.30233333], dtype=float32), 'std': Array([-0.21620178, -0.08162832], dtype=float32)}\n",
"\t Slowdist: (Array([2.0091841, 1.9760653], dtype=float32), Array([1.0162445, 1.0014191], dtype=float32))\n",
"Iteration 300\n",
"\t Loss: 10.307255744934082\n",
"\t Params: {'log_alpha_mean': Array([6.988981, 6.993453], dtype=float32), 'log_alpha_scale': Array([6.9325366, 6.979272 ], dtype=float32), 'log_penalty_temperature': Array([12.571055], dtype=float32), 'log_temperature': Array([6.9905663], dtype=float32), 'mean': Array([1.9995255, 1.9818168], dtype=float32), 'std': Array([1.0207626, 1.0161766], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(5.5323844, dtype=float32), 'loss_dist_scale': Array(5.5323844, dtype=float32), 'loss_penalty_temperature': Array(-1.611834, dtype=float32), 'loss_temperature': Array(0.70034486, dtype=float32), 'non_parametric_kl': Array(-3.1834134e-08, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00261007, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.01001551, -0.0099756 ], dtype=float32), 'log_alpha_scale': Array([-0.01025955, -0.01043388], dtype=float32), 'log_penalty_temperature': Array([0.00763704], dtype=float32), 'log_temperature': Array([-0.00999331], dtype=float32), 'mean': Array([0.00203451, 0.00038188], dtype=float32), 'std': Array([0.00180397, 0.00023839], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00999088, 0.00999092], dtype=float32), 'log_alpha_scale': Array([0.00099904, 0.00099908], dtype=float32), 'log_penalty_temperature': Array([-0.00161011], dtype=float32), 'log_temperature': Array([0.099909], dtype=float32), 'mean': Array([-0.15064903, 0.33702782], dtype=float32), 'std': Array([0.30689788, 0.11063433], dtype=float32)}\n",
"\t Slowdist: (Array([1.997491, 1.981435], dtype=float32), Array([1.0189586, 1.0159382], dtype=float32))\n",
"Iteration 400\n",
"\t Loss: 10.843290328979492\n",
"\t Params: {'log_alpha_mean': Array([5.9895415, 5.9943943], dtype=float32), 'log_alpha_scale': Array([5.943256 , 5.9878726], dtype=float32), 'log_penalty_temperature': Array([13.255847], dtype=float32), 'log_temperature': Array([5.9917774], dtype=float32), 'mean': Array([1.9738094, 1.9214609], dtype=float32), 'std': Array([1.0271099, 1.0046875], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(5.905237, dtype=float32), 'loss_dist_scale': Array(5.905237, dtype=float32), 'loss_penalty_temperature': Array(-1.7025176, dtype=float32), 'loss_temperature': Array(0.60328835, dtype=float32), 'non_parametric_kl': Array(1.1017386e-05, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00299899, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.01000584, -0.01003526], dtype=float32), 'log_alpha_scale': Array([-0.00982182, -0.01015128], dtype=float32), 'log_penalty_temperature': Array([0.00677194], dtype=float32), 'log_temperature': Array([-0.00998274], dtype=float32), 'mean': Array([ 0.00045743, -0.00172322], dtype=float32), 'std': Array([ 4.3866476e-03, -1.1541058e-05], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00997526, 0.00997538], dtype=float32), 'log_alpha_scale': Array([0.00099741, 0.00099752], dtype=float32), 'log_penalty_temperature': Array([-0.00199902], dtype=float32), 'log_temperature': Array([0.09974231], dtype=float32), 'mean': Array([-0.10582384, -0.05908114], dtype=float32), 'std': Array([-0.41858244, 0.07739139], dtype=float32)}\n",
"\t Slowdist: (Array([1.973352, 1.923184], dtype=float32), Array([1.0227232, 1.0046991], dtype=float32))\n",
"Iteration 500\n",
"\t Loss: 10.395594596862793\n",
"\t Params: {'log_alpha_mean': Array([4.9947243, 5.000006 ], dtype=float32), 'log_alpha_scale': Array([4.9454246, 5.0715246], dtype=float32), 'log_penalty_temperature': Array([13.882815], dtype=float32), 'log_temperature': Array([4.995106], dtype=float32), 'mean': Array([1.9434159, 1.8974422], dtype=float32), 'std': Array([1.0186335, 1.0160666], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(5.629676, dtype=float32), 'loss_dist_scale': Array(5.629676, dtype=float32), 'loss_penalty_temperature': Array(-1.4755543, dtype=float32), 'loss_temperature': Array(0.50146896, dtype=float32), 'non_parametric_kl': Array(5.776201e-08, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00210682, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00996913, -0.00976018], dtype=float32), 'log_alpha_scale': Array([-0.01021249, -0.00995974], dtype=float32), 'log_penalty_temperature': Array([0.00581957], dtype=float32), 'log_temperature': Array([-0.00995071], dtype=float32), 'mean': Array([-0.00067231, 0.0018401 ], dtype=float32), 'std': Array([0.00267545, 0.00258026], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00993338, 0.00993372], dtype=float32), 'log_alpha_scale': Array([0.00099301, 0.00099383], dtype=float32), 'log_penalty_temperature': Array([-0.00110682], dtype=float32), 'log_temperature': Array([0.09933402], dtype=float32), 'mean': Array([0.41739544, 0.00244189], dtype=float32), 'std': Array([-0.00261962, 0.20891345], dtype=float32)}\n",
"\t Slowdist: (Array([1.9440882, 1.8956021], dtype=float32), Array([1.0159581, 1.0134863], dtype=float32))\n",
"Iteration 600\n",
"\t Loss: 9.812423706054688\n",
"\t Params: {'log_alpha_mean': Array([4.0005674, 4.0051894], dtype=float32), 'log_alpha_scale': Array([3.9680495, 4.037309 ], dtype=float32), 'log_penalty_temperature': Array([14.470889], dtype=float32), 'log_temperature': Array([4.004064], dtype=float32), 'mean': Array([1.9193609, 1.8731921], dtype=float32), 'std': Array([1.0311558, 0.9853905], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(5.4563446, dtype=float32), 'loss_dist_scale': Array(5.4563446, dtype=float32), 'loss_penalty_temperature': Array(-1.6040164, dtype=float32), 'loss_temperature': Array(0.4150761, dtype=float32), 'non_parametric_kl': Array(0.00024706, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00168489, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00989726, -0.00993025], dtype=float32), 'log_alpha_scale': Array([-0.00884643, -0.0102415 ], dtype=float32), 'log_penalty_temperature': Array([0.00505015], dtype=float32), 'log_temperature': Array([-0.00986623], dtype=float32), 'mean': Array([-0.00374791, -0.00048407], dtype=float32), 'std': Array([-0.00455066, -0.00316194], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00982198, 0.00982279], dtype=float32), 'log_alpha_scale': Array([0.0009816 , 0.00098283], dtype=float32), 'log_penalty_temperature': Array([-0.00068493], dtype=float32), 'log_temperature': Array([0.09798305], dtype=float32), 'mean': Array([ 0.02225181, -0.12331976], dtype=float32), 'std': Array([0.39881384, 0.12143445], dtype=float32)}\n",
"\t Slowdist: (Array([1.9231088, 1.8736762], dtype=float32), Array([1.0357065, 0.9885524], dtype=float32))\n",
"Iteration 700\n",
"\t Loss: 9.773394584655762\n",
"\t Params: {'log_alpha_mean': Array([3.0240288, 3.0242302], dtype=float32), 'log_alpha_scale': Array([3.0639296, 3.1058667], dtype=float32), 'log_penalty_temperature': Array([14.951046], dtype=float32), 'log_temperature': Array([3.0276752], dtype=float32), 'mean': Array([1.855441 , 1.8280596], dtype=float32), 'std': Array([1.0362141, 0.9911518], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(5.4080253, dtype=float32), 'loss_dist_scale': Array(5.4080253, dtype=float32), 'loss_penalty_temperature': Array(-1.4297508, dtype=float32), 'loss_temperature': Array(0.31920096, dtype=float32), 'non_parametric_kl': Array(0.00024687, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00129923, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00965122, -0.00967342], dtype=float32), 'log_alpha_scale': Array([-0.00959231, -0.01005476], dtype=float32), 'log_penalty_temperature': Array([0.00442179], dtype=float32), 'log_temperature': Array([-0.00962961], dtype=float32), 'mean': Array([-0.00245786, 0.00038475], dtype=float32), 'std': Array([-0.00392805, 0.00177201], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00954073, 0.00954082], dtype=float32), 'log_alpha_scale': Array([0.00095579, 0.00095754], dtype=float32), 'log_penalty_temperature': Array([-0.00029927], dtype=float32), 'log_temperature': Array([0.09518743], dtype=float32), 'mean': Array([0.17081639, 0.11142206], dtype=float32), 'std': Array([0.45378053, 0.18009579], dtype=float32)}\n",
"\t Slowdist: (Array([1.8578988, 1.8276747], dtype=float32), Array([1.0401422, 0.9893798], dtype=float32))\n",
"Iteration 800\n",
"\t Loss: 10.055827140808105\n",
"\t Params: {'log_alpha_mean': Array([2.0888386, 2.0819914], dtype=float32), 'log_alpha_scale': Array([2.1432326, 2.1323292], dtype=float32), 'log_penalty_temperature': Array([15.426739], dtype=float32), 'log_temperature': Array([2.0883243], dtype=float32), 'mean': Array([1.7852235, 1.7838103], dtype=float32), 'std': Array([1.0529678 , 0.97169656], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(5.56331, dtype=float32), 'loss_dist_scale': Array(5.56331, dtype=float32), 'loss_penalty_temperature': Array(-1.346628, dtype=float32), 'loss_temperature': Array(0.22710684, dtype=float32), 'non_parametric_kl': Array(0.00018226, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00164119, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00909793, -0.00911971], dtype=float32), 'log_alpha_scale': Array([-0.00960647, -0.00892896], dtype=float32), 'log_penalty_temperature': Array([0.00495098], dtype=float32), 'log_temperature': Array([-0.00909024], dtype=float32), 'mean': Array([0.00113627, 0.00039891], dtype=float32), 'std': Array([-0.00063927, -0.00023052], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00890702, 0.00890036], dtype=float32), 'log_alpha_scale': Array([0.00089593, 0.00089485], dtype=float32), 'log_penalty_temperature': Array([-0.00064125], dtype=float32), 'log_temperature': Array([0.0889027], dtype=float32), 'mean': Array([0.1821757 , 0.13760634], dtype=float32), 'std': Array([0.05844939, 0.265746 ], dtype=float32)}\n",
"\t Slowdist: (Array([1.7840872, 1.7834114], dtype=float32), Array([1.0536071 , 0.97192705], dtype=float32))\n",
"Iteration 900\n",
"\t Loss: 10.414846420288086\n",
"\t Params: {'log_alpha_mean': Array([1.2190869, 1.2119188], dtype=float32), 'log_alpha_scale': Array([1.3263105, 1.3453943], dtype=float32), 'log_penalty_temperature': Array([15.860594], dtype=float32), 'log_temperature': Array([1.2309864], dtype=float32), 'mean': Array([1.7702087, 1.7440857], dtype=float32), 'std': Array([1.0903659, 1.0317504], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(5.7435527, dtype=float32), 'loss_dist_scale': Array(5.7435527, dtype=float32), 'loss_penalty_temperature': Array(-1.268894, dtype=float32), 'loss_temperature': Array(0.16385984, dtype=float32), 'non_parametric_kl': Array(0.00283572, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00157666, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00810354, -0.00821495], dtype=float32), 'log_alpha_scale': Array([-0.00761957, -0.00028315], dtype=float32), 'log_penalty_temperature': Array([0.00437291], dtype=float32), 'log_temperature': Array([-0.00797245], dtype=float32), 'mean': Array([-0.00025294, -0.00126575], dtype=float32), 'std': Array([ 0.00027253, -0.00415462], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00773326, 0.00772087], dtype=float32), 'log_alpha_scale': Array([0.00079149, 0.00079342], dtype=float32), 'log_penalty_temperature': Array([-0.00057664], dtype=float32), 'log_temperature': Array([0.07533956], dtype=float32), 'mean': Array([0.32433525, 0.23652917], dtype=float32), 'std': Array([-0.1077466, 0.4517902], dtype=float32)}\n",
"\t Slowdist: (Array([1.7704617, 1.7453514], dtype=float32), Array([1.0900934, 1.0359051], dtype=float32))\n",
"Iteration 1000\n",
"\t Loss: 10.428284645080566\n",
"\t Params: {'log_alpha_mean': Array([0.47500676, 0.46594504], dtype=float32), 'log_alpha_scale': Array([0.56197387, 0.76214904], dtype=float32), 'log_penalty_temperature': Array([16.298033], dtype=float32), 'log_temperature': Array([0.50975794], dtype=float32), 'mean': Array([1.6426165, 1.6870136], dtype=float32), 'std': Array([1.065806 , 1.0490845], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(5.843442, dtype=float32), 'loss_dist_scale': Array(5.843442, dtype=float32), 'loss_penalty_temperature': Array(-1.3940485, dtype=float32), 'loss_temperature': Array(0.11408404, dtype=float32), 'non_parametric_kl': Array(0.00438417, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00147438, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.0067487 , -0.00663688], dtype=float32), 'log_alpha_scale': Array([-0.00682014, -0.0060323 ], dtype=float32), 'log_penalty_temperature': Array([0.00372928], dtype=float32), 'log_temperature': Array([-0.00640762], dtype=float32), 'mean': Array([-0.00056 , -0.0004898], dtype=float32), 'std': Array([ 0.0005746 , -0.00354638], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00618162, 0.00615995], dtype=float32), 'log_alpha_scale': Array([0.00063848, 0.00068313], dtype=float32), 'log_penalty_temperature': Array([-0.00047437], dtype=float32), 'log_temperature': Array([0.05987953], dtype=float32), 'mean': Array([ 0.03570448, -0.02145896], dtype=float32), 'std': Array([ 0.16139364, -0.04688966], dtype=float32)}\n",
"\t Slowdist: (Array([1.6431766, 1.6875035], dtype=float32), Array([1.0652314, 1.0526309], dtype=float32))\n",
"Iteration 1100\n",
"\t Loss: 10.393377304077148\n",
"\t Params: {'log_alpha_mean': Array([-0.11661696, -0.14229466], dtype=float32), 'log_alpha_scale': Array([-0.04756668, 0.36868355], dtype=float32), 'log_penalty_temperature': Array([16.647552], dtype=float32), 'log_temperature': Array([-0.01090434], dtype=float32), 'mean': Array([1.5070807, 1.5932615], dtype=float32), 'std': Array([1.1037533, 1.0480413], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(5.703573, dtype=float32), 'loss_dist_scale': Array(5.703573, dtype=float32), 'loss_penalty_temperature': Array(-1.1448295, dtype=float32), 'loss_temperature': Array(0.11682957, dtype=float32), 'non_parametric_kl': Array(0.0506217, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00107339, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.0054071 , -0.00539334], dtype=float32), 'log_alpha_scale': Array([-0.0037757, -0.0066846], dtype=float32), 'log_penalty_temperature': Array([0.00297134], dtype=float32), 'log_temperature': Array([-0.00356754], dtype=float32), 'mean': Array([ 0.00225757, -0.00110431], dtype=float32), 'std': Array([-0.00127008, 0.0013447 ], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00472226, 0.00465828], dtype=float32), 'log_alpha_scale': Array([0.00048905, 0.00059276], dtype=float32), 'log_penalty_temperature': Array([-7.34051e-05], dtype=float32), 'log_temperature': Array([0.02459866], dtype=float32), 'mean': Array([0.0935921, 0.2123504], dtype=float32), 'std': Array([0.21292293, 0.27811265], dtype=float32)}\n",
"\t Slowdist: (Array([1.5048231, 1.5943658], dtype=float32), Array([1.1050234, 1.0466967], dtype=float32))\n",
"Iteration 1200\n",
"\t Loss: 11.944158554077148\n",
"\t Params: {'log_alpha_mean': Array([-0.5763378, -0.6066979], dtype=float32), 'log_alpha_scale': Array([-0.51761127, -0.07573829], dtype=float32), 'log_penalty_temperature': Array([16.980047], dtype=float32), 'log_temperature': Array([-0.21937615], dtype=float32), 'mean': Array([1.2289373, 1.4086984], dtype=float32), 'std': Array([1.1992322, 1.1608555], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(6.3966117, dtype=float32), 'loss_dist_scale': Array(6.3966117, dtype=float32), 'loss_penalty_temperature': Array(-1.0252453, dtype=float32), 'loss_temperature': Array(0.16621596, dtype=float32), 'non_parametric_kl': Array(0.12052727, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00136078, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00339909, -0.00437006], dtype=float32), 'log_alpha_scale': Array([-0.00374575, -0.00377474], dtype=float32), 'log_penalty_temperature': Array([0.00335167], dtype=float32), 'log_temperature': Array([-0.00056251], dtype=float32), 'mean': Array([-0.00148898, -0.00209857], dtype=float32), 'std': Array([ 0.00148613, -0.00135106], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00360559, 0.00353811], dtype=float32), 'log_alpha_scale': Array([0.00037429, 0.00048202], dtype=float32), 'log_penalty_temperature': Array([-0.0003608], dtype=float32), 'log_temperature': Array([-0.00914509], dtype=float32), 'mean': Array([0.20790292, 0.18558462], dtype=float32), 'std': Array([-0.08903313, -0.0103488 ], dtype=float32)}\n",
"\t Slowdist: (Array([1.2304262, 1.410797 ], dtype=float32), Array([1.197746 , 1.1622066], dtype=float32))\n",
"Iteration 1300\n",
"\t Loss: 13.061481475830078\n",
"\t Params: {'log_alpha_mean': Array([-0.9459383, -0.9843066], dtype=float32), 'log_alpha_scale': Array([-0.8510949 , -0.48921522], dtype=float32), 'log_penalty_temperature': Array([17.347765], dtype=float32), 'log_temperature': Array([-0.13230795], dtype=float32), 'mean': Array([0.8780919, 1.152789 ], dtype=float32), 'std': Array([1.3100417, 1.3071972], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(6.99171, dtype=float32), 'loss_dist_scale': Array(6.99171, dtype=float32), 'loss_penalty_temperature': Array(-1.1008959, dtype=float32), 'loss_temperature': Array(0.17164816, dtype=float32), 'non_parametric_kl': Array(0.07910566, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00157741, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00322689, -0.00307982], dtype=float32), 'log_alpha_scale': Array([-0.0026935 , -0.00232858], dtype=float32), 'log_penalty_temperature': Array([0.00347433], dtype=float32), 'log_temperature': Array([0.00197124], dtype=float32), 'mean': Array([-0.00393439, -0.00330106], dtype=float32), 'std': Array([-0.0012877 , 0.00077658], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00280353, 0.00272648], dtype=float32), 'log_alpha_scale': Array([0.00029977, 0.00038063], dtype=float32), 'log_penalty_temperature': Array([-0.00057742], dtype=float32), 'log_temperature': Array([0.00974678], dtype=float32), 'mean': Array([-0.13022321, -0.00566802], dtype=float32), 'std': Array([ 0.09032083, -0.45699072], dtype=float32)}\n",
"\t Slowdist: (Array([0.88202626, 1.15609 ], dtype=float32), Array([1.3113294, 1.3064207], dtype=float32))\n",
"Iteration 1400\n",
"\t Loss: 13.246981620788574\n",
"\t Params: {'log_alpha_mean': Array([-1.2637354, -1.2861173], dtype=float32), 'log_alpha_scale': Array([-1.163773 , -0.8891531], dtype=float32), 'log_penalty_temperature': Array([17.625298], dtype=float32), 'log_temperature': Array([0.06646895], dtype=float32), 'mean': Array([0.58107066, 0.83711004], dtype=float32), 'std': Array([1.3478258, 1.3444604], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(6.893543, dtype=float32), 'loss_dist_scale': Array(6.893543, dtype=float32), 'loss_penalty_temperature': Array(-0.8154192, dtype=float32), 'loss_temperature': Array(0.26975632, dtype=float32), 'non_parametric_kl': Array(0.11344148, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.0012234, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00272009, -0.00278943], dtype=float32), 'log_alpha_scale': Array([-0.00329222, -0.00427353], dtype=float32), 'log_penalty_temperature': Array([0.00218693], dtype=float32), 'log_temperature': Array([0.00140221], dtype=float32), 'mean': Array([-0.00216168, -0.00104239], dtype=float32), 'std': Array([0.00111423, 0.00174302], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00220799, 0.00216984], dtype=float32), 'log_alpha_scale': Array([0.00023858, 0.00029217], dtype=float32), 'log_penalty_temperature': Array([-0.0002234], dtype=float32), 'log_temperature': Array([-0.00693923], dtype=float32), 'mean': Array([0.07961895, 0.06169013], dtype=float32), 'std': Array([-0.05584896, 0.00676262], dtype=float32)}\n",
"\t Slowdist: (Array([0.58323234, 0.8381524 ], dtype=float32), Array([1.3467115, 1.3427174], dtype=float32))\n",
"Iteration 1500\n",
"\t Loss: 13.0104341506958\n",
"\t Params: {'log_alpha_mean': Array([-1.5256542, -1.5527204], dtype=float32), 'log_alpha_scale': Array([-1.4252537, -1.2386018], dtype=float32), 'log_penalty_temperature': Array([17.651665], dtype=float32), 'log_temperature': Array([0.20024928], dtype=float32), 'mean': Array([0.23097302, 0.5666039 ], dtype=float32), 'std': Array([1.2613862, 1.3232708], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(6.639682, dtype=float32), 'loss_dist_scale': Array(6.639682, dtype=float32), 'loss_penalty_temperature': Array(-0.6181773, dtype=float32), 'loss_temperature': Array(0.34488016, dtype=float32), 'non_parametric_kl': Array(0.12784487, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00081711, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00253495, -0.00240613], dtype=float32), 'log_alpha_scale': Array([-0.00233881, -0.00278399], dtype=float32), 'log_penalty_temperature': Array([-0.00117721], dtype=float32), 'log_temperature': Array([0.00049501], dtype=float32), 'mean': Array([-0.0031742 , -0.00022893], dtype=float32), 'std': Array([-0.00060993, -0.00232036], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00179003, 0.00175041], dtype=float32), 'log_alpha_scale': Array([0.00019421, 0.00022516], dtype=float32), 'log_penalty_temperature': Array([0.00018289], dtype=float32), 'log_temperature': Array([-0.01530839], dtype=float32), 'mean': Array([0.24209608, 0.10759301], dtype=float32), 'std': Array([-0.18880594, 0.27808654], dtype=float32)}\n",
"\t Slowdist: (Array([0.23414722, 0.56683284], dtype=float32), Array([1.261996 , 1.3255912], dtype=float32))\n",
"Iteration 1600\n",
"\t Loss: 12.547872543334961\n",
"\t Params: {'log_alpha_mean': Array([-1.7588459, -1.7869983], dtype=float32), 'log_alpha_scale': Array([-1.6077781, -1.5071139], dtype=float32), 'log_penalty_temperature': Array([17.392824], dtype=float32), 'log_temperature': Array([0.3215012], dtype=float32), 'mean': Array([-0.05936042, 0.36424693], dtype=float32), 'std': Array([1.1156728, 1.2697328], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(6.3736124, dtype=float32), 'loss_dist_scale': Array(6.3736124, dtype=float32), 'loss_penalty_temperature': Array(-0.5237775, dtype=float32), 'loss_temperature': Array(0.32089895, dtype=float32), 'non_parametric_kl': Array(0.13612637, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00071063, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00222026, -0.00225564], dtype=float32), 'log_alpha_scale': Array([ 0.00164221, -0.00285414], dtype=float32), 'log_penalty_temperature': Array([-0.00352807], dtype=float32), 'log_temperature': Array([0.00084911], dtype=float32), 'mean': Array([-0.00507259, -0.00177004], dtype=float32), 'std': Array([ 0.00315816, -0.00362787], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00147213, 0.00143719], dtype=float32), 'log_alpha_scale': Array([0.00016667, 0.00018179], dtype=float32), 'log_penalty_temperature': Array([0.00028939], dtype=float32), 'log_temperature': Array([-0.02093469], dtype=float32), 'mean': Array([0.24120304, 0.19993056], dtype=float32), 'std': Array([-0.02220583, 0.01736975], dtype=float32)}\n",
"\t Slowdist: (Array([-0.05428784, 0.36601698], dtype=float32), Array([1.1125146, 1.2733607], dtype=float32))\n",
"Iteration 1700\n",
"\t Loss: 11.571065902709961\n",
"\t Params: {'log_alpha_mean': Array([-1.9648342, -1.9892168], dtype=float32), 'log_alpha_scale': Array([-1.6330125, -1.7388906], dtype=float32), 'log_penalty_temperature': Array([16.921747], dtype=float32), 'log_temperature': Array([0.40935376], dtype=float32), 'mean': Array([-0.2143139 , 0.09657985], dtype=float32), 'std': Array([0.9693599, 1.1729561], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(5.7141495, dtype=float32), 'loss_dist_scale': Array(5.7141495, dtype=float32), 'loss_penalty_temperature': Array(-0.3204788, dtype=float32), 'loss_temperature': Array(0.460307, dtype=float32), 'non_parametric_kl': Array(0.16240694, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00040189, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.0017972 , -0.00163566], dtype=float32), 'log_alpha_scale': Array([-0.00116086, -0.00246525], dtype=float32), 'log_penalty_temperature': Array([-0.00602292], dtype=float32), 'log_temperature': Array([0.00095901], dtype=float32), 'mean': Array([ 0.00613796, -0.00044421], dtype=float32), 'std': Array([-7.146979e-05, -4.319504e-03], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00123139, 0.00120513], dtype=float32), 'log_alpha_scale': Array([0.00016358, 0.00014977], dtype=float32), 'log_penalty_temperature': Array([0.00059809], dtype=float32), 'log_temperature': Array([-0.03748804], dtype=float32), 'mean': Array([-0.09718732, -0.15199846], dtype=float32), 'std': Array([-0.06344247, 0.4360422 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.22045186, 0.09702406], dtype=float32), Array([0.96943134, 1.1772757 ], dtype=float32))\n",
"Iteration 1800\n",
"\t Loss: 10.886012077331543\n",
"\t Params: {'log_alpha_mean': Array([-2.150007 , -2.1710804], dtype=float32), 'log_alpha_scale': Array([-1.5323162, -1.7766854], dtype=float32), 'log_penalty_temperature': Array([16.2125], dtype=float32), 'log_temperature': Array([0.5223798], dtype=float32), 'mean': Array([-0.3659829 , -0.11126229], dtype=float32), 'std': Array([0.7850756, 0.9995463], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(5.3692656, dtype=float32), 'loss_dist_scale': Array(5.3692656, dtype=float32), 'loss_penalty_temperature': Array(-0.2714685, dtype=float32), 'loss_temperature': Array(0.41641247, dtype=float32), 'non_parametric_kl': Array(0.08041066, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00032137, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00180764, -0.00178288], dtype=float32), 'log_alpha_scale': Array([0.00374182, 0.00242694], dtype=float32), 'log_penalty_temperature': Array([-0.00792233], dtype=float32), 'log_temperature': Array([0.00024397], dtype=float32), 'mean': Array([-0.00089479, 0.00077892], dtype=float32), 'std': Array([0.00186791, 0.00145471], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.001045 , 0.00102542], dtype=float32), 'log_alpha_scale': Array([0.00017711, 0.00014441], dtype=float32), 'log_penalty_temperature': Array([0.0006786], dtype=float32), 'log_temperature': Array([0.01229519], dtype=float32), 'mean': Array([ 0.04270515, -0.3861392 ], dtype=float32), 'std': Array([-0.11069083, -0.28594434], dtype=float32)}\n",
"\t Slowdist: (Array([-0.3650881 , -0.11204121], dtype=float32), Array([0.7832077, 0.9980916], dtype=float32))\n",
"Iteration 1900\n",
"\t Loss: 9.107159614562988\n",
"\t Params: {'log_alpha_mean': Array([-2.3197582, -2.3397439], dtype=float32), 'log_alpha_scale': Array([-1.4396453, -1.8691139], dtype=float32), 'log_penalty_temperature': Array([15.370003], dtype=float32), 'log_temperature': Array([0.59381837], dtype=float32), 'mean': Array([-0.42314735, -0.22344188], dtype=float32), 'std': Array([0.64479566, 0.91510457], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(4.4020996, dtype=float32), 'loss_dist_scale': Array(4.4020996, dtype=float32), 'loss_penalty_temperature': Array(-0.16732755, dtype=float32), 'loss_temperature': Array(0.46807468, dtype=float32), 'non_parametric_kl': Array(0.09569712, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00019419, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00166317, -0.00149785], dtype=float32), 'log_alpha_scale': Array([0.01352011, 0.00115625], dtype=float32), 'log_penalty_temperature': Array([-0.00894727], dtype=float32), 'log_temperature': Array([5.8931255e-05], dtype=float32), 'mean': Array([-0.00137651, 0.00207637], dtype=float32), 'std': Array([-0.00076207, 0.0026572 ], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00089635, 0.00088005], dtype=float32), 'log_alpha_scale': Array([0.00018951, 0.00013351], dtype=float32), 'log_penalty_temperature': Array([0.00080575], dtype=float32), 'log_temperature': Array([0.00277202], dtype=float32), 'mean': Array([-0.03199971, -0.00045451], dtype=float32), 'std': Array([ 0.70049524, -0.02407002], dtype=float32)}\n",
"\t Slowdist: (Array([-0.42177084, -0.22551826], dtype=float32), Array([0.6455577, 0.9124474], dtype=float32))\n",
"Iteration 2000\n",
"\t Loss: 8.340788841247559\n",
"\t Params: {'log_alpha_mean': Array([-2.456947 , -2.4883978], dtype=float32), 'log_alpha_scale': Array([-0.83445656, -1.9364644 ], dtype=float32), 'log_penalty_temperature': Array([14.427636], dtype=float32), 'log_temperature': Array([0.68290627], dtype=float32), 'mean': Array([-0.5906134 , -0.36252683], dtype=float32), 'std': Array([0.59498936, 0.797119 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(3.9837809, dtype=float32), 'loss_dist_scale': Array(3.9837809, dtype=float32), 'loss_penalty_temperature': Array(-0.16984057, dtype=float32), 'loss_temperature': Array(0.54095536, dtype=float32), 'non_parametric_kl': Array(0.08146161, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00020885, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00147673, -0.00149742], dtype=float32), 'log_alpha_scale': Array([ 0.0188882 , -0.00165148], dtype=float32), 'log_penalty_temperature': Array([-0.00956709], dtype=float32), 'log_temperature': Array([-9.4544586e-05], dtype=float32), 'mean': Array([-0.00633501, 0.00218428], dtype=float32), 'std': Array([ 0.00682069, -0.00095406], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00079039, 0.00076782], dtype=float32), 'log_alpha_scale': Array([0.00029873, 0.00012622], dtype=float32), 'log_penalty_temperature': Array([0.00079113], dtype=float32), 'log_temperature': Array([0.01231694], dtype=float32), 'mean': Array([0.36718503, 0.2499914 ], dtype=float32), 'std': Array([-0.17570448, 0.5790248 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.5842784, -0.3647111], dtype=float32), Array([0.5881687 , 0.79807305], dtype=float32))\n",
"Iteration 2100\n",
"\t Loss: 6.896862506866455\n",
"\t Params: {'log_alpha_mean': Array([-2.5990117, -2.625048 ], dtype=float32), 'log_alpha_scale': Array([ 0.5877083, -1.9314023], dtype=float32), 'log_penalty_temperature': Array([13.457586], dtype=float32), 'log_temperature': Array([0.7340979], dtype=float32), 'mean': Array([-0.6412752, -0.4639231], dtype=float32), 'std': Array([0.5442644 , 0.64757186], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(3.2003965, dtype=float32), 'loss_dist_scale': Array(3.2003965, dtype=float32), 'loss_penalty_temperature': Array(-0.11266573, dtype=float32), 'loss_temperature': Array(0.6061672, dtype=float32), 'non_parametric_kl': Array(0.078012, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00014412, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00139315, -0.00136317], dtype=float32), 'log_alpha_scale': Array([0.02413624, 0.00151237], dtype=float32), 'log_penalty_temperature': Array([-0.00997739], dtype=float32), 'log_temperature': Array([-3.4192726e-05], dtype=float32), 'mean': Array([9.4851508e-05, 2.1299592e-03], dtype=float32), 'std': Array([ 0.0016123 , -0.00591179], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00069292, 0.0006763 ], dtype=float32), 'log_alpha_scale': Array([0.00063728, 0.00012643], dtype=float32), 'log_penalty_temperature': Array([0.00085584], dtype=float32), 'log_temperature': Array([0.01485759], dtype=float32), 'mean': Array([ 0.27730727, -0.39022574], dtype=float32), 'std': Array([0.6570852, 0.6844051], dtype=float32)}\n",
"\t Slowdist: (Array([-0.64137006, -0.46605307], dtype=float32), Array([0.5426521, 0.6534836], dtype=float32))\n",
"Iteration 2200\n",
"\t Loss: 5.311196804046631\n",
"\t Params: {'log_alpha_mean': Array([-2.7324965, -2.7575915], dtype=float32), 'log_alpha_scale': Array([ 1.4546008, -1.3423479], dtype=float32), 'log_penalty_temperature': Array([12.440416], dtype=float32), 'log_temperature': Array([0.76946247], dtype=float32), 'mean': Array([-0.6691051, -0.5819451], dtype=float32), 'std': Array([0.46756318, 0.5527357 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(2.276059, dtype=float32), 'loss_dist_scale': Array(2.276059, dtype=float32), 'loss_penalty_temperature': Array(-0.06675544, dtype=float32), 'loss_temperature': Array(0.8226953, dtype=float32), 'non_parametric_kl': Array(0.11348477, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(9.156967e-05, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00129202, -0.00121036], dtype=float32), 'log_alpha_scale': Array([0.00484031, 0.00192277], dtype=float32), 'log_penalty_temperature': Array([-0.01026074], dtype=float32), 'log_temperature': Array([0.00106293], dtype=float32), 'mean': Array([ 0.00045585, -0.00377636], dtype=float32), 'std': Array([-0.00517121, -0.00102045], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00061157, 0.00059727], dtype=float32), 'log_alpha_scale': Array([0.00080996, 0.00020681], dtype=float32), 'log_penalty_temperature': Array([0.00090843], dtype=float32), 'log_temperature': Array([-0.00921247], dtype=float32), 'mean': Array([-0.18599452, 0.12220959], dtype=float32), 'std': Array([1.3813591, 1.4182489], dtype=float32)}\n",
"\t Slowdist: (Array([-0.66956097, -0.57816875], dtype=float32), Array([0.4727344, 0.5537561], dtype=float32))\n",
"Iteration 2300\n",
"\t Loss: 4.692106246948242\n",
"\t Params: {'log_alpha_mean': Array([-2.8534713, -2.8690355], dtype=float32), 'log_alpha_scale': Array([2.0471725 , 0.13146144], dtype=float32), 'log_penalty_temperature': Array([11.408697], dtype=float32), 'log_temperature': Array([0.8000543], dtype=float32), 'mean': Array([-0.6372059 , -0.64016896], dtype=float32), 'std': Array([0.37509105, 0.5431576 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(2.0379157, dtype=float32), 'loss_dist_scale': Array(2.0379157, dtype=float32), 'loss_penalty_temperature': Array(-0.07064232, dtype=float32), 'loss_temperature': Array(0.6828854, dtype=float32), 'non_parametric_kl': Array(0.07788844, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00010004, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00106187, -0.0012204 ], dtype=float32), 'log_alpha_scale': Array([0.01081586, 0.00433503], dtype=float32), 'log_penalty_temperature': Array([-0.0103283], dtype=float32), 'log_temperature': Array([-0.00019934], dtype=float32), 'mean': Array([0.00196795, 0.00243106], dtype=float32), 'std': Array([-0.00032353, -0.00163485], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00054557, 0.00053768], dtype=float32), 'log_alpha_scale': Array([0.00088456, 0.00053174], dtype=float32), 'log_penalty_temperature': Array([0.00089996], dtype=float32), 'log_temperature': Array([0.01525766], dtype=float32), 'mean': Array([ 0.53430986, -0.4800687 ], dtype=float32), 'std': Array([1.5927801, 0.6046827], dtype=float32)}\n",
"\t Slowdist: (Array([-0.63917387, -0.6426 ], dtype=float32), Array([0.37541458, 0.5447924 ], dtype=float32))\n",
"Iteration 2400\n",
"\t Loss: 4.198826313018799\n",
"\t Params: {'log_alpha_mean': Array([-2.9525423, -2.980736 ], dtype=float32), 'log_alpha_scale': Array([2.4084244 , 0.80146295], dtype=float32), 'log_penalty_temperature': Array([10.396122], dtype=float32), 'log_temperature': Array([0.8096197], dtype=float32), 'mean': Array([-0.6637495 , -0.69964975], dtype=float32), 'std': Array([0.33132267, 0.4892195 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(1.7508824, dtype=float32), 'loss_dist_scale': Array(1.7508824, dtype=float32), 'loss_penalty_temperature': Array(-0.08890528, dtype=float32), 'loss_temperature': Array(0.78130347, dtype=float32), 'non_parametric_kl': Array(0.11710778, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00016188, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00107273, -0.00108827], dtype=float32), 'log_alpha_scale': Array([0.00585714, 0.00444589], dtype=float32), 'log_penalty_temperature': Array([-0.0101672], dtype=float32), 'log_temperature': Array([6.6331086e-06], dtype=float32), 'mean': Array([-0.00210207, -0.00128249], dtype=float32), 'std': Array([-0.00226152, -0.00476074], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00049667, 0.00048354], dtype=float32), 'log_alpha_scale': Array([0.00091702, 0.00068934], dtype=float32), 'log_penalty_temperature': Array([0.00083809], dtype=float32), 'log_temperature': Array([-0.01183898], dtype=float32), 'mean': Array([-0.45073563, 0.2154806 ], dtype=float32), 'std': Array([0.71651936, 0.80619526], dtype=float32)}\n",
"\t Slowdist: (Array([-0.66164744, -0.69836724], dtype=float32), Array([0.3335842 , 0.49398023], dtype=float32))\n",
"Iteration 2500\n",
"\t Loss: 3.302090883255005\n",
"\t Params: {'log_alpha_mean': Array([-3.0564716, -3.092617 ], dtype=float32), 'log_alpha_scale': Array([2.786175 , 1.3331501], dtype=float32), 'log_penalty_temperature': Array([9.360169], dtype=float32), 'log_temperature': Array([0.80660987], dtype=float32), 'mean': Array([-0.66668504, -0.6842781 ], dtype=float32), 'std': Array([0.2908691, 0.4426876], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(1.2415444, dtype=float32), 'loss_dist_scale': Array(1.2415444, dtype=float32), 'loss_penalty_temperature': Array(-0.06395072, dtype=float32), 'loss_temperature': Array(0.8776398, dtype=float32), 'non_parametric_kl': Array(0.09520049, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00011058, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00091714, -0.00101344], dtype=float32), 'log_alpha_scale': Array([0.00414436, 0.00073057], dtype=float32), 'log_penalty_temperature': Array([-0.01050976], dtype=float32), 'log_temperature': Array([-0.00016694], dtype=float32), 'mean': Array([-0.00357646, 0.00386907], dtype=float32), 'std': Array([ 7.3645823e-04, -5.2013595e-05], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00044978, 0.00043455], dtype=float32), 'log_alpha_scale': Array([0.0009417 , 0.00079124], dtype=float32), 'log_penalty_temperature': Array([0.00088935], dtype=float32), 'log_temperature': Array([0.0033185], dtype=float32), 'mean': Array([-0.04133159, 0.34382966], dtype=float32), 'std': Array([0.34080172, 1.266612 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.6631086, -0.6881471], dtype=float32), Array([0.29013264, 0.4427396 ], dtype=float32))\n",
"Iteration 2600\n",
"\t Loss: 2.4357240200042725\n",
"\t Params: {'log_alpha_mean': Array([-3.1564276, -3.1977131], dtype=float32), 'log_alpha_scale': Array([3.1509507, 1.7455796], dtype=float32), 'log_penalty_temperature': Array([8.325622], dtype=float32), 'log_temperature': Array([0.7955268], dtype=float32), 'mean': Array([-0.70309126, -0.6842369 ], dtype=float32), 'std': Array([0.25166315, 0.37185827], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(0.71509874, dtype=float32), 'loss_dist_scale': Array(0.71509874, dtype=float32), 'loss_penalty_temperature': Array(-0.06664184, dtype=float32), 'loss_temperature': Array(1.0662508, dtype=float32), 'non_parametric_kl': Array(0.11403795, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00012324, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00098323, -0.00098954], dtype=float32), 'log_alpha_scale': Array([ 0.0023826 , -0.00260392], dtype=float32), 'log_penalty_temperature': Array([-0.0104158], dtype=float32), 'log_temperature': Array([0.00047265], dtype=float32), 'mean': Array([-0.00741834, -0.00101 ], dtype=float32), 'std': Array([-0.00091901, -0.00624351], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00040877, 0.00039289], dtype=float32), 'log_alpha_scale': Array([0.00095885, 0.00085172], dtype=float32), 'log_penalty_temperature': Array([0.00087653], dtype=float32), 'log_temperature': Array([-0.00967095], dtype=float32), 'mean': Array([1.5016037, 0.9584311], dtype=float32), 'std': Array([0.4960642, 1.0620732], dtype=float32)}\n",
"\t Slowdist: (Array([-0.6956729, -0.6832269], dtype=float32), Array([0.25258216, 0.37810177], dtype=float32))\n",
"Iteration 2700\n",
"\t Loss: 1.2465910911560059\n",
"\t Params: {'log_alpha_mean': Array([-3.2478104, -3.297833 ], dtype=float32), 'log_alpha_scale': Array([3.3466885, 2.1322896], dtype=float32), 'log_penalty_temperature': Array([7.2462516], dtype=float32), 'log_temperature': Array([0.7884901], dtype=float32), 'mean': Array([-0.73617065, -0.69644904], dtype=float32), 'std': Array([0.23319091, 0.31601974], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(0.06008823, dtype=float32), 'loss_dist_scale': Array(0.06008823, dtype=float32), 'loss_penalty_temperature': Array(-0.02687555, dtype=float32), 'loss_temperature': Array(1.146919, dtype=float32), 'non_parametric_kl': Array(0.09441743, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(8.500862e-05, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00103505, -0.00099401], dtype=float32), 'log_alpha_scale': Array([-0.00123729, 0.00064391], dtype=float32), 'log_penalty_temperature': Array([-0.01070267], dtype=float32), 'log_temperature': Array([0.0001058], dtype=float32), 'mean': Array([0.00441471, 0.00168713], dtype=float32), 'std': Array([-0.00428923, -0.00169005], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00037443, 0.0003568 ], dtype=float32), 'log_alpha_scale': Array([0.00096604, 0.00089394], dtype=float32), 'log_penalty_temperature': Array([0.00091434], dtype=float32), 'log_temperature': Array([0.00383787], dtype=float32), 'mean': Array([-0.4347672 , 0.06996777], dtype=float32), 'std': Array([1.3900118, 1.7754488], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7405854 , -0.69813615], dtype=float32), Array([0.23748013, 0.3177098 ], dtype=float32))\n",
"Iteration 2800\n",
"\t Loss: 0.8522613644599915\n",
"\t Params: {'log_alpha_mean': Array([-3.333727 , -3.3931324], dtype=float32), 'log_alpha_scale': Array([3.7861686, 2.6131463], dtype=float32), 'log_penalty_temperature': Array([6.194662], dtype=float32), 'log_temperature': Array([0.78711265], dtype=float32), 'mean': Array([-0.7609844, -0.745074 ], dtype=float32), 'std': Array([0.23116761, 0.2870624 ], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(-0.11281026, dtype=float32), 'loss_dist_scale': Array(-0.11281026, dtype=float32), 'loss_penalty_temperature': Array(-0.05427381, dtype=float32), 'loss_temperature': Array(1.1249924, dtype=float32), 'non_parametric_kl': Array(0.10725211, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(0.00014902, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00102781, -0.00092125], dtype=float32), 'log_alpha_scale': Array([0.00924895, 0.0017472 ], dtype=float32), 'log_penalty_temperature': Array([-0.01039724], dtype=float32), 'log_temperature': Array([-0.00022391], dtype=float32), 'mean': Array([-0.00220119, -0.00133538], dtype=float32), 'std': Array([ 0.00100698, -0.00305621], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00034466, 0.0003254 ], dtype=float32), 'log_alpha_scale': Array([0.00097762, 0.00093159], dtype=float32), 'log_penalty_temperature': Array([0.00084926], dtype=float32), 'log_temperature': Array([-0.00498392], dtype=float32), 'mean': Array([0.9991412, 1.3850368], dtype=float32), 'std': Array([0.75854874, 1.988039 ], dtype=float32)}\n",
"\t Slowdist: (Array([-0.7587832, -0.7437386], dtype=float32), Array([0.23016064, 0.2901186 ], dtype=float32))\n",
"Iteration 2900\n",
"\t Loss: -0.18246911466121674\n",
"\t Params: {'log_alpha_mean': Array([-3.4260716, -3.4798765], dtype=float32), 'log_alpha_scale': Array([4.249098 , 3.2621253], dtype=float32), 'log_penalty_temperature': Array([5.1533737], dtype=float32), 'log_temperature': Array([0.7354052], dtype=float32), 'mean': Array([-0.80483425, -0.75524163], dtype=float32), 'std': Array([0.20568475, 0.24975947], dtype=float32)}\n",
"\t Metrics: {'loss_dist_mean': Array(-0.69869435, dtype=float32), 'loss_dist_scale': Array(-0.69869435, dtype=float32), 'loss_penalty_temperature': Array(-0.02461838, dtype=float32), 'loss_temperature': Array(1.2313594, dtype=float32), 'non_parametric_kl': Array(0.08800526, dtype=float32), 'parametric_kl_mean': Array(0., dtype=float32), 'parametric_kl_scale': Array(0., dtype=float32), 'penalty_non_parametric_kl': Array(8.6738524e-05, dtype=float32)}\n",
"\t Updates: {'log_alpha_mean': Array([-0.00092478, -0.00081869], dtype=float32), 'log_alpha_scale': Array([0.00459711, 0.00396084], dtype=float32), 'log_penalty_temperature': Array([-0.01030952], dtype=float32), 'log_temperature': Array([-0.00095799], dtype=float32), 'mean': Array([-0.00298282, 0.00087718], dtype=float32), 'std': Array([ 0.00244214, -0.00095704], dtype=float32)}\n",
"\t Grad: {'log_alpha_mean': Array([0.00031519, 0.00029914], dtype=float32), 'log_alpha_scale': Array([0.00098586, 0.00096297], dtype=float32), 'log_penalty_temperature': Array([0.000908], dtype=float32), 'log_temperature': Array([0.00811102], dtype=float32), 'mean': Array([ 1.2864131 , -0.31208917], dtype=float32), 'std': Array([1.6993513, 1.9797974], dtype=float32)}\n",
"\t Slowdist: (Array([-0.80185145, -0.75611883], dtype=float32), Array([0.20324261, 0.2507165 ], dtype=float32))\n"
]
}
],
"source": [
"\n",
"optimizer = optax.adam(1e-2)\n",
"key = jax.random.PRNGKey(41)\n",
"target = jnp.array([-0.75, -0.75])\n",
"n_act_samples = 100\n",
"std = 0.3\n",
"slowdist_update_freq = 25\n",
"objective_multiplier = 1.\n",
"objective_constraint = True\n",
"slowdist_constraint = True\n",
"penalty_constraint = True\n",
"\n",
"@jax.jit\n",
"def loss_fn(params, key, slowdist):\n",
" metrics = {}\n",
" mean = params['mean']\n",
" std = params['std']\n",
" temperature = jax.nn.softplus(params['log_temperature']) + 1e-8\n",
" penalty_temperature = jax.nn.softplus(params['log_penalty_temperature']) + 1e-8\n",
" alpha_mean = jax.nn.softplus(params['log_alpha_mean']) + 1e-8\n",
" alpha_scale = jax.nn.softplus(params['log_alpha_scale']) + 1e-8\n",
"\n",
" dist = tfd.Independent(tfd.Normal(mean, std), reinterpreted_batch_ndims=1)\n",
" a_improvement = slowdist.sample(n_act_samples, seed=key)\n",
" q_improvement = jax.lax.stop_gradient(objective_multiplier * jnp.sum(jnp.exp(-30 * (a_improvement - target)**2), -1))\n",
"\n",
"\n",
" def compute_weights_and_temperature_loss(q_values, epsilon, temperature):\n",
" tempered_q_values = jax.lax.stop_gradient(q_values) / temperature\n",
" q_logsumexp = jnp.mean(jnp.log(jnp.mean(jnp.exp(tempered_q_values), axis=0)))\n",
" loss_temperature = jnp.mean(temperature * epsilon + temperature * q_logsumexp)\n",
" normalized_weights = jax.lax.stop_gradient(jax.nn.softmax(tempered_q_values, axis=0))\n",
" num_action_samples = normalized_weights.shape[0]\n",
" integrand = jnp.log(num_action_samples * normalized_weights + 1e-8)\n",
" non_parametric_kl = jnp.sum(normalized_weights * integrand, axis=0)\n",
" return normalized_weights, loss_temperature, non_parametric_kl\n",
"\n",
" objective_normalized_weights, objective_loss_temperature, objective_non_parametric_kl = compute_weights_and_temperature_loss(q_improvement, 0.1, temperature)\n",
" metrics['loss_temperature'] = objective_loss_temperature\n",
" metrics['non_parametric_kl'] = objective_non_parametric_kl\n",
"\n",
" diff_out_of_bound = a_improvement - jnp.clip(a_improvement, -1.0, 1.0)\n",
" cost_out_of_bound = -jnp.linalg.norm(diff_out_of_bound, axis=-1)\n",
" penalty_normalized_weights, loss_penalty_temperature, penalty_non_parametric_kl = compute_weights_and_temperature_loss(cost_out_of_bound, 0.001, penalty_temperature)\n",
" metrics['loss_penalty_temperature'] = loss_penalty_temperature\n",
" metrics['penalty_non_parametric_kl'] = penalty_non_parametric_kl\n",
"\n",
" normalized_weights = jnp.zeros_like(q_improvement)\n",
" loss_temperature = 0.\n",
" if objective_constraint:\n",
" loss_temperature += objective_loss_temperature\n",
" normalized_weights += objective_normalized_weights\n",
" if penalty_constraint:\n",
" loss_temperature += loss_penalty_temperature\n",
" normalized_weights += penalty_normalized_weights\n",
"\n",
" dist_fixed_scale = tfd.Independent(tfd.Normal(\n",
" loc=dist.distribution.mean(),\n",
" scale=jax.lax.stop_gradient(slowdist.distribution.stddev())), 1)\n",
" dist_fixed_mean = tfd.Independent(tfd.Normal(\n",
" loc=jax.lax.stop_gradient(slowdist.distribution.mean()),\n",
" scale=dist.distribution.stddev()), 1)\n",
"\n",
" def compute_parametric_kl_penalty_and_dual_loss(kl, alpha, epsilon):\n",
" loss_kl = jnp.sum(jax.lax.stop_gradient(alpha) * kl, -1)\n",
" loss_alpha = jnp.sum(alpha * (epsilon - jax.lax.stop_gradient(kl)), -1)\n",
" return loss_kl, loss_alpha\n",
"\n",
" kl_mean = slowdist.distribution.kl_divergence(dist_fixed_scale.distribution)\n",
" loss_kl_mean, loss_alpha_mean= compute_parametric_kl_penalty_and_dual_loss(kl_mean, alpha_mean, 0.01)\n",
" metrics['parametric_kl_mean'] = jnp.mean(kl_mean)\n",
" kl_scale = slowdist.distribution.kl_divergence(dist_fixed_mean.distribution)\n",
" loss_kl_scale, loss_alpha_scale = compute_parametric_kl_penalty_and_dual_loss(kl_scale, alpha_scale, 0.001)\n",
" metrics['parametric_kl_scale'] = jnp.mean(kl_scale)\n",
"\n",
" def compute_cross_entropy_loss(distribution, sampled_actions, normalized_weights):\n",
" logpi = distribution.log_prob(jax.lax.stop_gradient(sampled_actions))\n",
" return jnp.mean(-jnp.sum(normalized_weights * logpi, axis=0))\n",
"\n",
" loss_dist_mean = compute_cross_entropy_loss(dist_fixed_scale, a_improvement, normalized_weights)\n",
" metrics['loss_dist_mean'] = loss_dist_mean\n",
" loss_dist_scale = compute_cross_entropy_loss(dist_fixed_mean, a_improvement, normalized_weights)\n",
" metrics['loss_dist_scale'] = loss_dist_scale\n",
"\n",
" loss_dist = loss_dist_mean + loss_dist_scale\n",
" loss = loss_dist + loss_temperature\n",
" if slowdist_constraint:\n",
" loss += loss_kl_mean + loss_kl_scale\n",
" loss += loss_alpha_mean + loss_alpha_scale\n",
"\n",
" return loss, metrics\n",
"\n",
"\n",
"params = {\n",
" 'mean': jnp.array([2.0, 2.0]),\n",
" 'std': jnp.array([1.0, 1.0]),\n",
" 'log_temperature': jnp.array([10.]),\n",
" 'log_penalty_temperature': jnp.array([10.]),\n",
" 'log_alpha_mean': jnp.array([10., 10.]),\n",
" 'log_alpha_scale': jnp.array([10., 10.]),}\n",
"opt_state = optimizer.init(params)\n",
"slowdist = tfd.Independent(tfd.Normal(\n",
" jax.lax.stop_gradient(params['mean']),\n",
" jax.lax.stop_gradient(std * jnp.ones_like(params['mean']))), 1)\n",
"\n",
"means = []\n",
"stds = []\n",
"log_temperatures = []\n",
"for i in range(3000):\n",
" if i % slowdist_update_freq == 0:\n",
" slowdist = tfd.Independent(tfd.Normal(\n",
" jax.lax.stop_gradient(params['mean']),\n",
" jax.lax.stop_gradient(params['std'])), 1)\n",
" _, key = jax.random.split(key)\n",
" (loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, key, slowdist)\n",
" updates, opt_state = optimizer.update(grads, opt_state)\n",
" params = optax.apply_updates(params, updates)\n",
" means.append(params['mean'])\n",
" stds.append(params['std'])\n",
" log_temperatures.append(params['log_temperature'])\n",
" if i % 100 == 0:\n",
" print(f'Iteration {i}')\n",
" print(f'\\t Loss: {loss}')\n",
" print(f'\\t Params: {params}')\n",
" print(f'\\t Metrics: {metrics}')\n",
" print(f'\\t Updates: {updates}')\n",
" print(f'\\t Grad: {grads}')\n",
" print(f'\\t Slowdist: {slowdist.mean(), slowdist.stddev()}')"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"from matplotlib import patches\n",
"a = np.array(means)\n",
"plt.scatter(a[:, 0], a[:, 1])\n",
"plt.scatter(a[0, 0], a[0, 1], color='r')\n",
"\n",
"if objective_constraint:\n",
" plt.scatter(target[0], target[1], color='#39FF14')\n",
"\n",
"if penalty_constraint:\n",
" plt.plot([-1, 1], [1, 1], color='r')\n",
" plt.plot([-1, 1], [-1, -1], color='r')\n",
" plt.plot([-1, -1], [-1, 1], color='r')\n",
" plt.plot([1, 1], [-1, 1], color='r')\n",
"\n",
"for i in range(0, len(means), 300):\n",
" curr_mean = means[i]\n",
" curr_std = stds[i]\n",
" ell = patches.Ellipse(curr_mean, curr_std[0], curr_std[1], color='r', fill=False)\n",
" plt.gca().add_patch(ell)\n",
"\n",
"plt.gca().set_aspect('equal', adjustable='box')\n",
"plt.title(f'{len(means)} iterations')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f99580df970>,\n",
" <matplotlib.lines.Line2D at 0x7f993849cc70>]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(stds)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(jax.nn.softplus(jnp.array(log_temperatures_10)), label='10')\n",
"plt.plot(jax.nn.softplus(jnp.array(log_temperatures_1)), label='1')\n",
"plt.plot(jax.nn.softplus(jnp.array(log_temperatures_0p1)), label='0.1')\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "awake",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment