Skip to content

Instantly share code, notes, and snippets.

@dominicrufa
Created July 1, 2021 05:39
Show Gist options
  • Save dominicrufa/6dcf8dcc2b62f2d7edea6331fba7910f to your computer and use it in GitHub Desktop.
Save dominicrufa/6dcf8dcc2b62f2d7edea6331fba7910f to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "47793b2b-3c77-4e34-9f19-bba8d1465d1b",
"metadata": {},
"source": [
"# Annealed Flow Transport Monte Carlo"
]
},
{
"cell_type": "markdown",
"id": "2d5ce5d2-51d2-4850-ba20-2f0bdbeeb140",
"metadata": {},
"source": [
"This notebook writes up a toy example of annealed flow transport Monte Carlo."
]
},
{
"cell_type": "code",
"execution_count": 181,
"id": "a278cbd7-a082-42cb-9689-3b53d7b7b49d",
"metadata": {},
"outputs": [],
"source": [
"from typing import Sequence, Callable, Dict, Tuple, Optional\n",
"import jax\n",
"import flax.linen as nn\n",
"import jax.numpy as jnp\n",
"from functools import partial\n",
"from jax import lax, ops, vmap, jit, grad, random\n",
"from jax.scipy.special import logsumexp\n",
"\n",
"from jax.config import config\n",
"from typing import Callable\n",
"\n",
"config.update(\"jax_enable_x64\", True)  # use float64 throughout\n",
"\n",
"# aliases of jnp.array, mostly used in type annotations\n",
"Conf = jnp.array\n",
"Params = Array = Seed = Conf\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9352b0e6-d28c-410f-b4d0-3cdb55cc9ad0",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"id": "8cf47625-c949-42a9-ac58-0182f1ada847",
"metadata": {},
"source": [
"## Defining simple distributions to anneal between"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8ab9efef-daaa-4ade-9c4b-dc188eba7987",
"metadata": {},
"outputs": [],
"source": [
"def unnormalized_Normal_logp(x: Conf,\n",
"                             mu: Conf,\n",
"                             cov: Conf) -> Conf:\n",
"    \"\"\"\n",
"    log-density, up to an additive constant, of a gaussian with diagonal covariance\n",
"\n",
"    arguments\n",
"        x : jnp.array(Dx)\n",
"            position\n",
"        mu : jnp.array(Dx)\n",
"            mean vector\n",
"        cov : jnp.array(Dx)\n",
"            diagonal of the covariance matrix (a scalar also broadcasts)\n",
"    returns\n",
"        out : float\n",
"            -0.5 * sum((x - mu)**2 / cov)\n",
"    \"\"\"\n",
"    displacement = x - mu\n",
"    whitened = displacement / cov\n",
"    return -0.5 * whitened.dot(displacement)\n",
"\n",
"def Normal_energy(x: Conf,\n",
"                  mu: Conf,\n",
"                  cov: Conf) -> Conf:\n",
"    \"\"\"reduced energy of a diagonal gaussian: the negated unnormalized log-density\"\"\"\n",
"    logp = unnormalized_Normal_logp(x, mu, cov)\n",
"    return -logp\n",
"\n",
"def unnormalized_gmm_logp(x: Conf,\n",
"                          mus: Conf,\n",
"                          covs: Conf,\n",
"                          lws: Conf) -> Conf:\n",
"    \"\"\"\n",
"    unnormalized log-density of a gaussian mixture model with diagonal covariances\n",
"\n",
"    arguments\n",
"        x : jnp.array(Dx)\n",
"            position\n",
"        mus : jnp.array(K, Dx)\n",
"            component means\n",
"        covs : jnp.array(K, Dx)\n",
"            diagonals of the component covariances\n",
"        lws : jnp.array(K)\n",
"            (unnormalized) log mixture weights\n",
"    returns\n",
"        out : float\n",
"            log(sum_k exp(lws[k] + unnormalized_Normal_logp(x, mus[k], covs[k])))\n",
"    \"\"\"\n",
"    # evaluate every component at x; mus and covs are mapped over their first axis\n",
"    unnorm_logps = vmap(unnormalized_Normal_logp, in_axes=(None, 0, 0))(x, mus, covs)\n",
"    # logsumexp rather than log(sum(exp(.))) so points far from every component\n",
"    # do not underflow to -inf\n",
"    return logsumexp(lws + unnorm_logps)\n",
"\n",
"def gmm_energy(x: Conf,\n",
"               mus: Conf,\n",
"               covs: Conf,\n",
"               lws: Conf) -> Conf:\n",
"    \"\"\"reduced energy of a gaussian mixture: the negated unnormalized log-density\"\"\"\n",
"    logp = unnormalized_gmm_logp(x, mus, covs, lws)\n",
"    return -logp\n",
" \n",
"\n",
"# samplers\n",
"def sample_normal(seed: Conf,\n",
"                  N: int,\n",
"                  mu: Conf,\n",
"                  cov: Conf) -> Conf:\n",
"    \"\"\"\n",
"    draw N independent samples from a gaussian with diagonal covariance\n",
"\n",
"    returns\n",
"        samples : jnp.array(N, Dx)\n",
"    \"\"\"\n",
"    n_dims = len(mu)\n",
"    white_noise = random.normal(seed, shape=(N, n_dims))\n",
"    return white_noise * jnp.sqrt(cov) + mu\n",
"\n",
"def sample_gmm(seed: Conf,\n",
"               mus: Conf,\n",
"               covs: Conf,\n",
"               lws: Conf) -> Conf:\n",
"    \"\"\"\n",
"    draw a single sample from a gaussian mixture model with diagonal covariances\n",
"\n",
"    arguments\n",
"        seed : jax PRNGKey\n",
"        mus : jnp.array(K, Dx)\n",
"            component means\n",
"        covs : jnp.array(K, Dx)\n",
"            diagonals of the component covariances\n",
"        lws : jnp.array(K)\n",
"            (unnormalized) log mixture weights\n",
"    returns\n",
"        sample : jnp.array(Dx)\n",
"    \"\"\"\n",
"    # dimension of a single sample, taken from the component means\n",
"    dim = mus.shape[-1]\n",
"    seed1, seed2 = random.split(seed)\n",
"    # categorical accepts unnormalized log-weights directly\n",
"    mixture_idx = random.categorical(seed1, lws)\n",
"    return random.normal(seed2, shape=(dim,)) * jnp.sqrt(covs[mixture_idx]) + mus[mixture_idx]\n",
"\n",
"def kinetic_energy(m : Array, v: Array) -> float:\n",
"    \"\"\"\n",
"    reduced kinetic energy 0.5 * m * v.v of a velocity vector v\n",
"\n",
"    v is a velocity (the integrators here update it as v += force * dt / m),\n",
"    so the mass multiplies rather than divides. m may be a scalar or an array\n",
"    that broadcasts against v.\n",
"    \"\"\"\n",
"    return 0.5 * (m * v).dot(v)\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "be1e30ad-ad39-4afa-af03-426ba2a1a51b",
"metadata": {},
"outputs": [],
"source": [
"def free_energy(works):\n",
"    \"\"\"\n",
"    exponential-averaging (Jarzynski) estimate of the reduced free energy\n",
"    difference from an array of reduced works: -log(mean(exp(-works)))\n",
"    \"\"\"\n",
"    from jax.scipy.special import logsumexp\n",
"    shift = jnp.min(works)\n",
"    n_samples = len(works)\n",
"    # shifting by the smallest work keeps every exponent argument <= 0\n",
"    return shift - logsumexp(shift - works) + jnp.log(n_samples)\n",
"\n",
"def ESS(works):\n",
"    \"\"\"\n",
"    effective sample size of the importance weights exp(-works), divided by\n",
"    the number of samples (so it lies between 1/N and 1)\n",
"    \"\"\"\n",
"    log_w = -works\n",
"    normalized_w = jnp.exp(log_w - logsumexp(log_w))\n",
"    return 1. / jnp.sum(normalized_w**2) / len(works)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d8693bb0-cc66-49ef-b443-82ce98d31ea4",
"metadata": {},
"outputs": [],
"source": [
"def BAOAB_kernel(X: Conf,\n",
"                 seed: Seed,\n",
"                 energy_parameters: Params,\n",
"                 energy_fn : Callable[[Conf, Params], float],\n",
"                 dt: float,\n",
"                 gamma: float,\n",
"                 mass: Params) -> Conf:\n",
"    \"\"\"\n",
"    a metropolized Langevin kernel with a V-R-O-R-V (BAOAB-ordered) splitting\n",
"\n",
"    every V (velocity) and R (position) sub-step uses the full `dt`; the O step is an\n",
"    exact Ornstein-Uhlenbeck velocity update with collision rate `gamma`. the shadow\n",
"    work of the deterministic sub-steps decides a Metropolis accept/reject.\n",
"\n",
"    arguments\n",
"        X : jnp.array(2*D)\n",
"            positions concatenated with velocities\n",
"        seed : jax PRNGKey\n",
"        energy_parameters : Params\n",
"            second argument of `energy_fn`\n",
"        energy_fn : callable(x, energy_parameters) -> float\n",
"            reduced potential energy\n",
"        dt : float\n",
"            sub-step size\n",
"        gamma : float\n",
"            collision rate\n",
"        mass : Params\n",
"            particle mass (the calls shown here pass a scalar)\n",
"    returns\n",
"        out_X : jnp.array(2*D)\n",
"            the proposal [r2, v3] if accepted, otherwise [r0, -v0]\n",
"    \"\"\"\n",
"    twice_dim = X.shape[0]  # dim(x) + dim(v)\n",
"    single_dim = twice_dim // 2  # dim(x) == dim(v)\n",
"    noise_seed, MH_seed = random.split(seed)\n",
"    noise = random.normal(noise_seed, shape=(single_dim,))\n",
"    # O step v <- a*v + b*sqrt(1/m)*noise leaves the velocity distribution N(0, 1/m) invariant\n",
"    a, b = jnp.exp(-gamma * dt), jnp.sqrt(1. - jnp.exp(-2. * gamma * dt))\n",
"    r0, v0 = X[:single_dim], X[single_dim:]\n",
"\n",
"    def force(r):\n",
"        return -grad(energy_fn)(r, energy_parameters)\n",
"\n",
"    e0 = energy_fn(r0, energy_parameters) + kinetic_energy(mass, v0)  # initial total energy\n",
"    v1 = v0 + force(r0) * dt / mass  # V\n",
"    r1 = r0 + v1 * dt  # R\n",
"    k_b = kinetic_energy(mass, v1)  # kinetic energy before O\n",
"    # O: damping keeps the sign of the velocity; a negative coefficient would\n",
"    # reverse the motion on every step\n",
"    v2 = a * v1 + b * jnp.sqrt(1. / mass) * noise\n",
"    k_a = kinetic_energy(mass, v2)  # kinetic energy after O\n",
"    r2 = r1 + v2 * dt  # R\n",
"    v3 = v2 + force(r2) * dt / mass  # V\n",
"    ef = energy_fn(r2, energy_parameters) + kinetic_energy(mass, v3)  # final total energy\n",
"\n",
"    # shadow work: total energy change minus the heat exchanged in the O step\n",
"    lw = ef - e0 - (k_a - k_b)\n",
"    log_acceptance_prob = jnp.minimum(0., -lw)\n",
"    lu = jnp.log(random.uniform(MH_seed))\n",
"    accept = (lu <= log_acceptance_prob)\n",
"    # on rejection, stay at r0 with the velocity reversed\n",
"    out_X = jnp.where(accept, jnp.concatenate([r2, v3]), jnp.concatenate([r0, -v0]))\n",
"    return out_X"
]
},
{
"cell_type": "markdown",
"id": "d0cbc9f7-a829-4ae5-a6e1-db668d81c650",
"metadata": {},
"source": [
"Build an annealed importance sampling (AIS) sampler, then write a small test for it."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7b533744-5a6b-49aa-acc9-671046eabd94",
"metadata": {},
"outputs": [],
"source": [
"def make_AIS_sampler(energy_fn: Callable[[Conf, Params], float], # energy function\n",
"                     energy_parameters: Params, # energy parameters, one row per annealing step\n",
"                     forward_kernel: Callable[[Conf, Seed, Params], Conf], # forward kernel\n",
"                     t0_sampler: Callable[[Seed], Conf] # sampler of [x, v] at t=0\n",
"                     ) -> Callable[[Seed], Tuple[float, Conf, Conf]]:\n",
"    \"\"\"\n",
"    build a jitted annealed importance sampling (AIS) sampler for one particle;\n",
"    vmap the returned function over seeds to run N particles.\n",
"\n",
"    arguments\n",
"        energy_fn : callable(x, params) -> float\n",
"            reduced potential energy\n",
"        energy_parameters : jnp.array(T, ...)\n",
"            energy parameters for each of the T annealing steps\n",
"        forward_kernel : callable(X, seed, params) -> X\n",
"            propagates a phase-space point X = [x, v] at fixed params\n",
"        t0_sampler : callable(seed) -> X\n",
"            draws the initial phase-space point [x, v]\n",
"    returns\n",
"        AIS_sampler : callable(seed) -> (cumulative_work, traj, work_traj)\n",
"            cumulative_work : float\n",
"                accumulated reduced work\n",
"            traj : jnp.array(T, 2*D)\n",
"                [x, v] at t=0 and after every kernel application\n",
"            work_traj : jnp.array(T-1)\n",
"                running work after each annealing step\n",
"    \"\"\"\n",
"    len_energy_parameters = energy_parameters.shape[0]\n",
"    ts = jnp.arange(len_energy_parameters)[1:]\n",
"\n",
"    def execute_iteration(state: Tuple[Conf, Array, Seed],\n",
"                          t: int) -> Tuple[Tuple[Conf, Array, Seed], Tuple[Conf, Array]]:\n",
"        x_tm1, energy_tally, seed = state  # separate the state\n",
"        _full_dim = x_tm1.shape[0]  # dim(x) + dim(v)\n",
"        _particle_dim = _full_dim // 2  # dim(x) == dim(v)\n",
"        run_seed, next_seed = random.split(seed)\n",
"        # work increment: energy change from switching params t-1 -> t at fixed x\n",
"        energy_diff = (energy_fn(x_tm1[:_particle_dim], energy_parameters[t])\n",
"                       - energy_fn(x_tm1[:_particle_dim], energy_parameters[t-1]))\n",
"        x_t = forward_kernel(x_tm1, run_seed, energy_parameters[t])  # propagate at the new params\n",
"        new_energy_tally = energy_tally + energy_diff\n",
"        new_state = (x_t, new_energy_tally, next_seed)\n",
"        # carry the full state; emit (x_t, running work) as the per-step output\n",
"        return new_state, new_state[:2]\n",
"\n",
"    def AIS_sampler(seed: Seed) -> Tuple[float, Conf, Conf]:\n",
"        init_seed, seed = random.split(seed)\n",
"        x0 = t0_sampler(init_seed)  # sample [x, v] at t0\n",
"        init = (x0, 0., seed)  # start with zero work\n",
"        (x_T, cumulative_work, final_seed), (traj, work_traj) = lax.scan(execute_iteration, init, ts)\n",
"        out = (cumulative_work, jnp.vstack([x0[jnp.newaxis, ...], traj]), work_traj)\n",
"        return out\n",
"\n",
"    # jit the new sampler\n",
"    jAIS_sampler = jit(AIS_sampler)\n",
"\n",
"    return jAIS_sampler"
]
},
{
"cell_type": "markdown",
"id": "b6743fb0-d0a0-4997-b29c-cb58c4b74c37",
"metadata": {},
"source": [
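"1D test: anneal a zero-mean Gaussian from variance 1 to variance 2. Since $Z \\propto \\sigma$, the exact reduced free-energy difference is $\\Delta f = -\\ln(Z_1/Z_0) = -\\frac{1}{2}\\ln 2 \\approx -0.347$."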
]
},
{
"cell_type": "code",
"execution_count": 407,
"id": "5a9889d8-c955-4363-8768-5da43428cc62",
"metadata": {},
"outputs": [],
"source": [
"def energy_fn(x, param):\n",
"    \"\"\"1D test energy: zero-mean gaussian whose variance is param[0]\"\"\"\n",
"    variance = param[0]\n",
"    return Normal_energy(x, jnp.array([0.]), variance)"
]
},
{
"cell_type": "code",
"execution_count": 408,
"id": "e099eeeb-6a2b-4bfb-859a-9fa033848dc1",
"metadata": {},
"outputs": [],
"source": [
"def _t0_sampler(seed):\n",
"    \"\"\"draw [x, v] at t=0 for the 1D test: x ~ N(0, 1), v ~ N(0, 1)\"\"\"\n",
"    xseed, vseed = random.split(seed)\n",
"    x = sample_normal(xseed, 1, jnp.zeros((1,)), jnp.ones((1,)))[0]\n",
"    v = random.normal(vseed, shape=(1,))\n",
"    return jnp.concatenate([x, v])"
]
},
{
"cell_type": "code",
"execution_count": 409,
"id": "3d3343d2-6426-4475-82b6-a0beeb9cb196",
"metadata": {},
"outputs": [],
"source": [
"_forward_kernel = partial(BAOAB_kernel,\n",
"                          energy_fn=energy_fn,\n",
"                          dt=1e-1,\n",
"                          gamma=1.,\n",
"                          mass=1.)"
]
},
{
"cell_type": "code",
"execution_count": 410,
"id": "9b930c2f-f99f-4d86-a758-5f48062cf742",
"metadata": {},
"outputs": [],
"source": [
"energy_parameters = jnp.linspace(1., 2., 10)[:, None]  # variances 1 -> 2, shape (10, 1)"
]
},
{
"cell_type": "code",
"execution_count": 411,
"id": "12b01f37-c4e2-42d9-ab2a-e375f20eacea",
"metadata": {},
"outputs": [],
"source": [
"ais_sampler = make_AIS_sampler(energy_fn=energy_fn,\n",
"                               energy_parameters=energy_parameters,\n",
"                               forward_kernel=_forward_kernel,\n",
"                               t0_sampler=_t0_sampler)"
]
},
{
"cell_type": "code",
"execution_count": 412,
"id": "79f1302a-a0c9-495a-8a2b-d1ed79163a7e",
"metadata": {},
"outputs": [],
"source": [
"vais_sampler = jax.vmap(ais_sampler)  # one particle per PRNG key"
]
},
{
"cell_type": "code",
"execution_count": 416,
"id": "11704c8f-2aed-47b1-aff4-03adfc34b948",
"metadata": {},
"outputs": [],
"source": [
"key = random.PRNGKey(33)\n",
"all_keys = random.split(key, 100)  # one key per particle"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e0cdcb81-b610-4add-a71e-f21c6c2f3fe9",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 417,
"id": "c476beac-63f4-4e83-8db2-b5577904f646",
"metadata": {},
"outputs": [],
"source": [
"lws, trajs, lws_trajs = vais_sampler(all_keys)"
]
},
{
"cell_type": "code",
"execution_count": 425,
"id": "01045110-1024-4b67-ad42-6db3335feb3f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([ 2., 0., 2., 3., 3., 3., 8., 6., 7., 66.]),\n",
" array([-1.68757860e+00, -1.51883749e+00, -1.35009639e+00, -1.18135529e+00,\n",
" -1.01261419e+00, -8.43873084e-01, -6.75131981e-01, -5.06390878e-01,\n",
" -3.37649776e-01, -1.68908673e-01, -1.67570249e-04]),\n",
" <BarContainer object of 10 artists>)"
]
},
"execution_count": 425,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAOHElEQVR4nO3dYYxlZ13H8e/PLg1awHbt7GakhBGzqTQmtM2k1tSQ6FJSW8MuL2raRJ2YJhsSMG2iMau+wXetiURNDMlK0VERrUCzG0BgHWmICVZmS1taF1hKSqkddocqUnwBKfx9MWfLMHt37pmZe+/cR76f5Oac89xz7vPv6dNfzzxzz5lUFZKk9vzIbhcgSdoeA1ySGmWAS1KjDHBJapQBLkmN2jPJzq688sqam5ubZJeS1LxTp059vapmNrZPNMDn5uZYXl6eZJeS1LwkXxnU7hSKJDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1aqJ3YkrSbpo7+pFd6/vpe28b+Wd6BS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGtUrwJNcnuQDST6f5HSSn0+yN8nJJGe65RXjLlaS9H19r8D/FPhYVf0M8AbgNHAUWKqqA8BSty1JmpChAZ7kVcAbgfsBquo7VfUN4BCw2O22CBweT4mSpEH6XIG/DlgF/jLJZ5O8J8llwP6qWgHolvvGWKckaYM+Ab4HuB54d1VdB/wvW5guSXIkyXKS5dXV1W2WKUnaqE+APws8W1UPd9sfYC3QzyaZBeiW5wYdXFXHqmq+quZnZmZGUbMkiR4BXlVfA76a5Oqu6SDwH8AJYKFrWwCOj6VCSdJAff8q/W8B70tyKfBl4DdZC/8HktwFPAPcPp4SJUmD9ArwqnoUmB/w1sGRViNJ6s07MSWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElq1J4+OyV5GngB+C7wYlXNJ9kL/AMwBzwN/GpV/fd4ypQkbbSVK/BfrKprq2q+2z4KLFXVAWCp25YkTchOplAOAYvd+iJweMfVSJJ66xvgBXwiyakkR7q2/VW1AtAt9w06MMmRJMtJlldXV3desSQJ6DkHDtxUVc8l2QecTPL5vh1U1THgGMD8/Hxto0ZJ0gC9rsCr6rlueQ54ELgBOJtkFqBbnhtXkZKkCw0N8CSXJXnl+XXgzcATwAlgodttATg+riIlSRfqM4WyH3gwyfn9/66qPpbkM8ADSe4CngFuH1+ZkqSNhgZ4VX0ZeMOA9ueBg+MoSpI0nHdiSlKjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRvUO8CSXJPlskg9323uTnExyplteMb4yJUkbbeUK/G7g9Lrto8BSVR0AlrptSdKE9ArwJFcBtwHvWdd8CFjs1heBwyOtTJK0qb5X4H8C/C7wvXVt+6tqBaBb7ht0YJIjSZaTLK+uru6kVknSOkMDPMmvAOeq6tR2OqiqY1U1X1XzMzMz2/kISdIAe3rscxPwliS3Ai8HXpXkb4GzSWaraiXJLHBunIVKkn7Q0Cvwqvq9qrqqquaAO4B/qapfA04AC91uC8DxsVUpSbrATr4Hfi9wc5IzwM3dtiRpQvpMobykqh4CHurWnwcOjr4kSVIf3okpSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgE
tSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqOGBniSlyf59ySPJXkyyR927XuTnExyplteMf5yJUnn9bkC/zbwS1X1BuBa4JYkNwJHgaWqOgAsdduSpAkZGuC15lvd5su6VwGHgMWufRE4PI4CJUmD9ZoDT3JJkkeBc8DJqnoY2F9VKwDdct9Fjj2SZDnJ8urq6ojKliT1CvCq+m5VXQtcBdyQ5Gf7dlBVx6pqvqrmZ2ZmtlmmJGmjLX0Lpaq+ATwE3AKcTTIL0C3Pjbo4SdLF9fkWykySy7v1HwXeBHweOAEsdLstAMfHVKMkaYA9PfaZBRaTXMJa4D9QVR9O8mnggSR3Ac8At4+xTknSBkMDvKoeB64b0P48cHAcRUmShvNOTElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1KihAZ7kNUk+meR0kieT3N21701yMsmZbnnF+MuVJJ3X5wr8ReC3q+r1wI3A25NcAxwFlqrqALDUbUuSJmRogFfVSlU90q2/AJwGXg0cAha73RaBw2OqUZI0wJbmwJPMAdcBDwP7q2oF1kIe2HeRY44kWU6yvLq6usNyJUnn9Q7wJK8APgjcU1Xf7HtcVR2rqvmqmp+ZmdlOjZKkAXoFeJKXsRbe76uqD3XNZ5PMdu/PAufGU6IkaZA+30IJcD9wuqrete6tE8BCt74AHB99eZKki9nTY5+bgF8HPpfk0a7t94F7gQeS3AU8A9w+lgolSQMNDfCq+lcgF3n74GjLkST15Z2YktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUUMDPMl7k5xL8sS6tr1JTiY50y2vGG+ZkqSN+lyB/xVwy4a2o8BSVR0AlrptSdIEDQ3wqvoU8F8bmg8Bi936InB4tGVJkobZ7hz4/qpaAeiW+y62Y5IjSZaTLK+urm6zO0nSRmP/JWZVHauq+aqan5mZGXd3kvRDY7sBfjbJLEC3PDe6kiRJfWw3wE8AC936AnB8NOVIkvrq8zXC9wOfBq5O8mySu4B7gZuTnAFu7rYlSRO0Z9gOVXXnRd46OOJaJElb4J2YktSooVfgksZr7uhHdqXfp++9bVf61eh4BS5JjTLAJalRTqFImrjdmjb6/8YrcElqlAEuSY1yCkX6IeU0Rvu8ApekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1yht5dAFv8JDa4BW4JDXKAJekRhngktSoZubAd3Nedrf+9JRz0ZI24xW4JDXKAJekRu0owJPckuQLSb6U5OioipIkDbftAE9yCfDnwC8D1wB3JrlmVIVJkja3kyvwG4AvVdWXq+o7wN8Dh0ZTliRpmJ18C+XVwFfXbT8L/NzGnZIcAY50m99K8oV1b18JfH0HNUxE7vuBzSZq3sCaJ8OaJ6PFmsl9O6r7tYMadxLgGdBWFzRUHQOODfyAZLmq5ndQw8RZ82RY82RY8+SMo+6dTKE8C7xm3fZVwHM7K0eS1NdOAvwzwIEkP5XkUuAO4MRoypIkDbPtKZSqejHJO4CPA5cA762qJ7f4MQOnVqacNU+GNU+GNU/OyOtO1QXT1pKkBngnpiQ1ygCXpEaNPcCT3J7kySTfSzLwKzRJrk7y6LrXN5Pc0733ziT/ue69W6eh5m6/p5N8rqtreV373iQnk5zplldMQ81JXpPkk0lOd/veve69aT7PAx/ZsEvneWifUziee52nKRvPfc7zVIznYY8UyZo/695/PMn1fY8dqqrG+gJeD1wNPATM99j/EuBrwGu77XcCvzPuOr
dTM/A0cOWA9j8CjnbrR4H7pqFmYBa4vlt/JfBF4JppPs/deHgKeB1wKfDYupp34zxvqc8pGc+9ap6y8Ty0z2kYz5uNz3X73Ar8E2v3ztwIPNz32GGvsV+BV9XpqvrC8D1fchB4qqq+Mq6ahtlGzRsdAha79UXg8I6LGqJPzVW1UlWPdOsvAKdZu6N2V/Q8z5s9smHi53kbfe76eGbn52kqz/OUjOc+jxQ5BPx1rfk34PIksz2P3dQ0zoHfAbx/Q9s7uh893juJH9+2oIBPJDmVtUcGnLe/qlZgbZAB+3aluk0kmQOuAx5e1zyN53nQIxvO/0e6G+d5q31Ow3juW/M0ject9bmL43mz8Tlsnz7HbmokAZ7kn5M8MeC1pf+bZO2GoLcA/7iu+d3ATwPXAivAH09RzTdV1fWsPZHx7UneOIraLmaE5/kVwAeBe6rqm13ztJ7nXo9sGCXHs+N5K90PaNs4Pi+2z47H9kj+pFpVvWkUn8PawHmkqs6u++yX1pP8BfDhUXQ0ipqr6rlueS7Jg6z9SPQp4GyS2apa6X5UOrfTvrp+dlxzkpexNtjfV1UfWvfZ03qeN3tkw8TPc5Kt9DkV47lvzdM0nvvWPOnxPECfR4pcbJ9Lexy7qWmbQrmTDT9udv/yznsr8MREK7qIJJcleeX5deDNfL+2E8BCt74AHJ98hRdKEuB+4HRVvWvDe1N5ntn8kQ27cZ630ue0jOehNU/heO5T8zSM5z6PFDkB/Mbal1FyI/A/3bTQzh9HMoHf0r6Vtf8DfRs4C3y8a/9J4KPr9vsx4Hngxzcc/zfA54DHu3+42WmombXfHD/WvZ4E/mDd8T8BLAFnuuXeKan5F1j7Ee1x4NHudes0n+f6/m/xv8jab+x3+zwP7HPKx/PQmqdwPPepeSrG86DxCbwNeFu3Htb++M1TXU3zmx27lZe30ktSo6ZtCkWS1JMBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhr1fxzet6s3L/bqAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.hist(lws)\n",
"plt.xlabel(\"cumulative reduced work\")\n",
"plt.ylabel(\"count\");"
]
},
{
"cell_type": "code",
"execution_count": 427,
"id": "574cc79c-73c3-4dc4-be8e-845f47b295fd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(-0.35377424, dtype=float64)"
]
},
"execution_count": 427,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"free_energy(works=lws)  # exact value for variance 1 -> 2: -0.5*log(2) ~ -0.347"
]
},
{
"cell_type": "code",
"execution_count": 428,
"id": "a32d6422-dddf-4a0f-bb9c-f876eb3bd791",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(0.76650916, dtype=float64)"
]
},
"execution_count": 428,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ESS(works=lws)  # normalized effective sample size"
]
},
{
"cell_type": "markdown",
"id": "5642aa11-4fcd-44c2-b38e-c8939078f988",
"metadata": {},
"source": [
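"Next, repeat the calculation in 2 dimensions with a trivial free-energy difference: translating the mean of a unit-variance Gaussian leaves its partition function unchanged, so the exact answer is $\\Delta f = 0$."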
]
},
{
"cell_type": "markdown",
"id": "5c5300ab-3873-4673-bc85-95eb83f56482",
"metadata": {},
"source": [
"# free energy in 2D"
]
},
{
"cell_type": "code",
"execution_count": 116,
"id": "654f90b3-2c5e-4174-83c3-610f18a2d100",
"metadata": {},
"outputs": [],
"source": [
"def _energy_fn(x, param):\n",
"    \"\"\"2D test energy: unit-variance gaussian whose mean is param\"\"\"\n",
"    return Normal_energy(x, param, jnp.ones(2))"
]
},
{
"cell_type": "code",
"execution_count": 117,
"id": "6b0eab41-82d2-465c-95bc-7b268c6328a2",
"metadata": {},
"outputs": [],
"source": [
"def _t0_sampler(seed):\n",
"    \"\"\"draw [x, v] at t=0 for the 2D test: x ~ N(0, I), v ~ N(0, I)\"\"\"\n",
"    xseed, vseed = random.split(seed)\n",
"    x = sample_normal(xseed, 1, jnp.zeros((2,)), jnp.ones((2,)))[0]\n",
"    v = random.normal(vseed, shape=(2,))\n",
"    return jnp.concatenate([x, v])"
]
},
{
"cell_type": "code",
"execution_count": 118,
"id": "1c405a26-f525-47a8-b4e4-a5f7b0a11140",
"metadata": {},
"outputs": [],
"source": [
"_forward_kernel = partial(BAOAB_kernel,\n",
"                          energy_fn=_energy_fn,\n",
"                          dt=1e-1,\n",
"                          gamma=1.,\n",
"                          mass=1.)"
]
},
{
"cell_type": "code",
"execution_count": 119,
"id": "d2957a95-b3b7-4ada-aa02-04549f001f0a",
"metadata": {},
"outputs": [],
"source": [
"mean_ramp = jnp.linspace(0., 10., 30000)  # move the mean from (0, 0) to (10, 10)\n",
"energy_parameters = jnp.stack([mean_ramp, mean_ramp], axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 120,
"id": "bf11f62c-6ae6-4206-8037-f5bf3806c928",
"metadata": {},
"outputs": [],
"source": [
"ais_sampler = make_AIS_sampler(energy_fn=_energy_fn,\n",
"                               energy_parameters=energy_parameters,\n",
"                               forward_kernel=_forward_kernel,\n",
"                               t0_sampler=_t0_sampler)"
]
},
{
"cell_type": "code",
"execution_count": 121,
"id": "495b553f-be87-4b50-8e03-66c34f5fa132",
"metadata": {},
"outputs": [],
"source": [
"vais_sampler = jax.vmap(ais_sampler)  # one particle per PRNG key"
]
},
{
"cell_type": "code",
"execution_count": 122,
"id": "3b1a9a7a-a1ec-44f5-85cf-c5f1871471e0",
"metadata": {},
"outputs": [],
"source": [
"key = random.PRNGKey(350)\n",
"all_keys = random.split(key, 50)  # one key per particle"
]
},
{
"cell_type": "code",
"execution_count": 123,
"id": "8fac4120-3515-4530-afe0-7a7a0640653f",
"metadata": {},
"outputs": [],
"source": [
"lws, trajs, lws_trajs = vais_sampler(all_keys)"
]
},
{
"cell_type": "code",
"execution_count": 124,
"id": "8d6dfd60-8768-4b7d-a030-0b83134c092f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([ 5., 4., 8., 12., 13., 3., 3., 1., 0., 1.]),\n",
" array([ 1.22216554, 2.74093758, 4.25970962, 5.77848166, 7.2972537 ,\n",
" 8.81602574, 10.33479778, 11.85356982, 13.37234186, 14.8911139 ,\n",
" 16.40988594]),\n",
" <BarContainer object of 10 artists>)"
]
},
"execution_count": 124,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAANMUlEQVR4nO3db4xldX3H8fenjEQXMWAYrLJsRxqktYQWMm1RUpuCJFuXgA/6ACJmW0kmaVpEY6tLSOqzZluN1cRGswFcEjeQZsVKJFo2qCVNkHZ3Qf4titEtLKI7hrRabYqbfvtgLulwd+fP3nN2zvyW9yvZzL1nzsz57vx577nn3nM2VYUkqT2/NPQAkqTJGHBJapQBl6RGGXBJapQBl6RGTa3lxs4666yamZlZy01KUvP27dv346qaHl++pgGfmZlh7969a7lJSWpekn8/1nIPoUhSowy4JDXKgEtSowy4JDXKgEtSowy4JDXKgEtSowy4JDXKgEtSo9b0TExpJTPb7h1kuwe3bxlku1IX7oFLUqMMuCQ1yoBLUqMMuCQ1yoBLUqMMuCQ1yoBLUqMMuCQ1yoBLUqMMuCQ1yoBLUqNWDHiS25McTvL4omUfS/JUkkeTfDHJGSd0SknSUVazB74T2Dy2bA9wYVVdBHwHuLnnuSRJK1gx4FX1APDC2LL7qurI6O43gY0nYDZJ0jL6OAb+PuArPXweSdJx6HQ98CS3AEeAXcusMwfMAWzatKnL5rRGhromt6TjM/EeeJKtwFXAe6qqllqvqnZU1WxVzU5PT0+6OUnSmIn2wJNsBj4C/H5V/bzfkSRJq7GalxHeCTwIXJDkUJIbgE8DpwN7kjyS5LMneE5J0pgV98Cr6rpjLL7tBMwiSToOnokpSY0y4JLUKAMuSY0y4JLUKAMuSY0y4JLUKAMuSY0y4JLUKAMuSY0y4JLUKAMuSY0y4JLUKAMuSY0y4JLUKAMuSY0y4JLUKAMuSY0y4JLUKAMuSY0y4JLUKAMuSY0y4JLUKAMuSY1aMeBJbk9yOMnji5a9PsmeJE+P3p55YseUJI1bzR74TmDz2LJtwP1VdT5w/+i+JGkNrRjwqnoAeGFs8TXAHaPbdwDv7ncsSdJKJj0G/oaqeh5g9Pbs/kaSJK3GCX8SM8lckr1J9s7Pz5/ozUnSK8akAf9RkjcCjN4eXmrFqtpRVbNVNTs9PT3h5iRJ4yYN+D3A1tHtrcCX+hlHkrRaq3kZ4Z3Ag8AFSQ4luQHYDlyZ5GngytF9SdIamlpphaq6bol3XdHzLJKk4+CZmJLUKAMuSY0y4JLUKAMuSY0y4JLUKAMuSY0y4JLUKAMuSY0y4JLUKAMuSY0y4JLUKAMuSY0y4JLUKAMuSY0y4JLUKAMuSY0y4JLUKAMuSY0y4JLUKAMuSY0y4JLUKAMuSY0y4JLUqE4BT/LBJE8keTzJnUle3ddgkqTlTRzwJOcA7wdmq+pC4BTg2r4GkyQtr+shlCngNUmmgA3AD7qPJElajalJP7CqnkvyceAZ4L+B+6rqvvH1kswBcwCbNm2adHOvSDPb7h16BEnrWJdDKGcC1wBvBt4EnJbk+vH1qmpHVc1W1ez09PTkk0qSXqbLIZR3At+vqvmq+gVwN/D2fsaSJK2kS8CfAS5NsiFJgCuAA/2MJUlaycQBr6qHgN3AfuCx0efa0dNckqQVTPwkJkBVfRT4aE+zSJKOg2diSlKjDLgkNcqAS1KjDLgkNcqAS1KjDLgkNcqAS1KjDLgkNcqAS1KjDLgkNcqAS1KjDLgkNcqAS1KjDLgkNcqAS1KjDLgkNcqAS1KjDLgkNcqAS1KjDLgkNcqAS1KjDLgkNcqAS1KjOgU8yRlJdid5KsmBJG/razBJ0vKmOn78p4CvVtUfJTkV2NDDTJKkVZg44EleB7wD+GOAqnoReLGfsSRJK+lyCOU8YB74XJKHk9ya5LTxlZLMJdmbZO/8/HyHzUmSFusS8CngEu
AzVXUx8DNg2/hKVbWjqmaranZ6errD5iRJi3UJ+CHgUFU9NLq/m4WgS5LWwMQBr6ofAs8muWC06ArgyV6mkiStqOurUG4Edo1egfI94E+6jyRJWo1OAa+qR4DZfkaRJB0Pz8SUpEYZcElqlAGXpEYZcElqlAGXpEYZcElqlAGXpEYZcElqlAGXpEYZcElqVNdroayZmW33Drbtg9u3DLZtSVqKe+CS1CgDLkmNMuCS1CgDLkmNMuCS1CgDLkmNMuCS1CgDLkmNMuCS1CgDLkmNMuCS1CgDLkmN6hzwJKckeTjJl/sYSJK0On3sgd8EHOjh80iSjkOngCfZCGwBbu1nHEnSanW9HvgngQ8Dpy+1QpI5YA5g06ZNHTc3jCGvRS5JS5l4DzzJVcDhqtq33HpVtaOqZqtqdnp6etLNSZLGdDmEchlwdZKDwF3A5Uk+38tUkqQVTRzwqrq5qjZW1QxwLfC1qrq+t8kkScvydeCS1Khe/lPjqvoG8I0+PpckaXXcA5ekRhlwSWqUAZekRhlwSWqUAZekRhlwSWqUAZekRhlwSWqUAZekRhlwSWpUL6fSS617JV7z/eD2LUOPoI7cA5ekRhlwSWqUAZekRhlwSWqUAZekRhlwSWqUAZekRhlwSWqUAZekRhlwSWqUAZekRhlwSWrUxAFPcm6Sryc5kOSJJDf1OZgkaXldrkZ4BPhQVe1PcjqwL8meqnqyp9kkScuYeA+8qp6vqv2j2z8FDgDn9DWYJGl5vRwDTzIDXAw8dIz3zSXZm2Tv/Px8H5uTJNFDwJO8FvgC8IGq+sn4+6tqR1XNVtXs9PR0181JkkY6BTzJq1iI966qurufkSRJq9HlVSgBbgMOVNUn+htJkrQaXfbALwPeC1ye5JHRn3f1NJckaQUTv4ywqv4FSI+zSJKOg2diSlKjDLgkNcqAS1KjDLgkNcqAS1KjDLgkNcqAS1KjDLgkNcqAS1KjDLgkNarL/8gjqWEz2+4dbNsHt28ZZLsn29/ZPXBJapQBl6RGGXBJapQBl6RGGXBJapQBl6RGGXBJapQBl6RGGXBJapQBl6RGGXBJapQBl6RGdQp4ks1Jvp3ku0m29TWUJGllEwc8ySnA3wN/CLwVuC7JW/saTJK0vC574L8DfLeqvldVLwJ3Adf0M5YkaSVdrgd+DvDsovuHgN8dXynJHDA3uvtfSb7dYZtdnAX8eKBtr4bzdeN83azpfPmb4/6Q5r9+E/ydF/uVYy3sEvAcY1kdtaBqB7Cjw3Z6kWRvVc0OPcdSnK8b5+vG+boZar4uh1AOAecuur8R+EG3cSRJq9Ul4P8GnJ/kzUlOBa4F7ulnLEnSSiY+hFJVR5L8OfBPwCnA7VX1RG+T9W/wwzgrcL5unK8b5+tmkPlSddRha0lSAzwTU5IaZcAlqVEndcCTnJvk60kOJHkiyU1Dz3QsSU5J8nCSLw89y7gkZyTZneSp0dfxbUPPtFiSD46+t48nuTPJq9fBTLcnOZzk8UXLXp9kT5KnR2/PXGfzfWz0PX40yReTnLGe5lv0vr9IUknOGmK20QzHnC/JjaNLizyR5G/XYpaTOuDAEeBDVfXrwKXAn63T0/1vAg4MPcQSPgV8tap+DfhN1tGcSc4B3g/MVtWFLDyZfu2wUwGwE9g8tmwbcH9VnQ/cP7o/lJ0cPd8e4MKqugj4DnDzWg+1yE6Ono8k5wJXAs+s9UBjdjI2X5I/YOFM9Iuq6jeAj6/FICd1wKvq+araP7r9Uxbic86wU71cko3AFuDWoWcZl+R1wDuA2wCq6sWq+o9BhzraFPCaJFPABtbBuQhV9QDwwtjia4A7RrfvAN69ljMtdqz5quq+qjoyuvtNFs7rGMQSXz+AvwM+zDFOGFxLS8z3p8D2qvqf0TqH12KWkzrgiyWZAS4GHhp4lHGfZOGH8n8HnuNYzgPmgc+NDvHcmuS0oYd6SVU9x8KezjPA88B/VtV9w061pDdU1fOwsGMBnD3wPM
t5H/CVoYdYLMnVwHNV9a2hZ1nCW4DfS/JQkn9O8ttrsdFXRMCTvBb4AvCBqvrJ0PO8JMlVwOGq2jf0LEuYAi4BPlNVFwM/Y9iH/i8zOo58DfBm4E3AaUmuH3aqtiW5hYVDj7uGnuUlSTYAtwB/NfQsy5gCzmThUO1fAv+Q5FiXG+nVSR/wJK9iId67quruoecZcxlwdZKDLFzN8fIknx92pJc5BByqqpcetexmIejrxTuB71fVfFX9ArgbePvAMy3lR0neCDB6uyYPsY9Hkq3AVcB7an2dIPKrLPwj/a3R78pGYH+SXx50qpc7BNxdC/6VhUfUJ/yJ1pM64KN/AW8DDlTVJ4aeZ1xV3VxVG6tqhoUn375WVetmD7Kqfgg8m+SC0aIrgCcHHGncM8ClSTaMvtdXsI6eZB1zD7B1dHsr8KUBZzlKks3AR4Crq+rnQ8+zWFU9VlVnV9XM6HflEHDJ6OdzvfhH4HKAJG8BTmUNrp54UgechT3c97KwZ/vI6M+7hh6qMTcCu5I8CvwW8NfDjvP/Ro8MdgP7gcdY+Hke/JTrJHcCDwIXJDmU5AZgO3BlkqdZeCXF9nU236eB04E9o9+Tz66z+daNJea7HThv9NLCu4Cta/EoxlPpJalRJ/seuCSdtAy4JDXKgEtSowy4JDXKgEtSowy4JDXKgEtSo/4P7twGNdXPc/IAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.hist(lws)\n",
"plt.xlabel(\"cumulative reduced work\")\n",
"plt.ylabel(\"count\");"
]
},
{
"cell_type": "code",
"execution_count": 125,
"id": "82ce4a80-70ba-4442-b232-449f40cf7613",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(50, 30000, 4)"
]
},
"execution_count": 125,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trajs.shape  # (particles, annealing steps, [x, v])"
]
},
{
"cell_type": "code",
"execution_count": 127,
"id": "23d9e3d0-06e4-43aa-989d-03f4d255026b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f24241cfcd0>]"
]
},
"execution_count": 127,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(trajs[20, :, 0])\n",
"plt.xlabel(\"annealing step\")\n",
"plt.ylabel(\"x[0] of particle 20\");"
]
},
{
"cell_type": "markdown",
"id": "f18e3a84-9254-49c5-b163-dd22058773a8",
"metadata": {},
"source": [
"OK, this seems to be working."
]
},
{
"cell_type": "markdown",
"id": "1ac8fe13-a4bc-4a94-be71-799a44edd94e",
"metadata": {},
"source": [
"## Write a deterministic, learnable sampler\n",
"The first implementation is a single 3-step map (velocity update, position update, velocity update) driven by a learnable force; it transforms positions and velocities with a Jacobian determinant of 1."
]
},
{
"cell_type": "code",
"execution_count": 134,
"id": "b75333ad-7c2f-4d38-a3f8-5d8ab44683d5",
"metadata": {},
"outputs": [],
"source": [
"from typing import Sequence, Callable\n",
"import flax.linen as nn\n",
"import flax\n",
"nnParams = flax.core.frozen_dict.FrozenDict  # type of the network parameters returned by model.init\n"
]
},
{
"cell_type": "code",
"execution_count": 131,
"id": "ac9124f8-300f-4f36-945e-bf9628b4ced6",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 150,
"id": "63b140cd-3708-46bf-930b-ec1164a9aece",
"metadata": {},
"outputs": [],
"source": [
"def make_mod_leapfrog(seed : Seed,\n",
"                      features: Sequence) -> Tuple[nnParams, Callable[[Conf, float, nnParams], Conf]]:\n",
"    \"\"\"\n",
"    make a modified leapfrog step whose force is a learnable tanhMLP\n",
"\n",
"    arguments\n",
"        seed : jax PRNGKey\n",
"            used to initialize the network parameters\n",
"        features : Sequence[int]\n",
"            layer widths passed to tanhMLP; features[0] is taken as dim(x), and the\n",
"            network output is added to v, so its shape must match v\n",
"    returns\n",
"        params : nnParams (FrozenDict)\n",
"            initial network parameters\n",
"        leapfrog : callable(X_in, dt, in_params) -> X_out\n",
"            maps [x, v] to [x', v'] via a velocity update, a position update and a\n",
"            second velocity update; each sub-step is a shear, so the Jacobian\n",
"            determinant of the map is exactly 1\n",
"\n",
"    NOTE(review): tanhMLP is not defined in any cell shown here; it must be\n",
"    defined before this function is called.\n",
"    \"\"\"\n",
"    _dim = features[0]  # dimension of the positions x\n",
"    model = tanhMLP(features=features)\n",
"    test_x = jnp.ones(_dim)\n",
"    params = model.init(seed, test_x)  # FrozenDict of network parameters\n",
"    model.apply(params, test_x)  # smoke-test one forward pass at construction time\n",
"\n",
"    def leapfrog(X_in : Conf, dt: float, in_params: nnParams) -> Conf:\n",
"        x_1, v_1 = X_in[:_dim], X_in[_dim:]  # pull apart x, v\n",
"        v_one_and_half = v_1 + dt * model.apply(in_params, x_1)\n",
"        x_2 = x_1 + dt * v_one_and_half\n",
"        v_2 = v_one_and_half + dt * model.apply(in_params, x_2)\n",
"        return jnp.concatenate([x_2, v_2])\n",
"\n",
"    return params, leapfrog"
]
},
{
"cell_type": "markdown",
"id": "65035b9e-bb07-4774-9e7c-ddc4f1aa7a52",
"metadata": {},
"source": [
"Test this numerically: is the Jacobian determinant actually 1?"
]
},
{
"cell_type": "code",
"execution_count": 151,
"id": "362a74db-f90e-4121-9068-9fa224b77ebe",
"metadata": {},
"outputs": [],
"source": [
"nn_params, deterministic_leapfrog = make_mod_leapfrog(seed=random.PRNGKey(3462), features=[2, 2])"
]
},
{
"cell_type": "code",
"execution_count": 154,
"id": "90c47c43-caf4-43dd-988a-4551e9339eb5",
"metadata": {},
"outputs": [],
"source": [
"X = random.normal(random.PRNGKey(2462), (4,))  # phase-space point [x (2), v (2)]"
]
},
{
"cell_type": "code",
"execution_count": 155,
"id": "6c88f983-781a-4977-834a-510fa016c5fc",
"metadata": {},
"outputs": [],
"source": [
"X_out = deterministic_leapfrog(X_in=X, dt=1e-1, in_params=nn_params)"
]
},
{
"cell_type": "code",
"execution_count": 157,
"id": "c66bde8f-afe9-41ce-be5f-77eda9165661",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([-0.96499554, 0.59350874, -1.38793319, 0.55717537], dtype=float64)"
]
},
"execution_count": 157,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X  # input [x, v]"
]
},
{
"cell_type": "code",
"execution_count": 156,
"id": "fe9e1421-0112-4b46-861f-db699657f20c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([-1.1180315 , 0.64786676, -1.68356186, 0.52882456], dtype=float64)"
]
},
"execution_count": 156,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_out  # [x, v] after one leapfrog step with dt = 0.1"
]
},
{
"cell_type": "code",
"execution_count": 158,
"id": "2060a535-7557-4bcd-ba79-77bbf46a2ba7",
"metadata": {},
"outputs": [],
"source": [
"from jax import jacfwd"
]
},
{
"cell_type": "code",
"execution_count": 159,
"id": "00db1b97-abdf-4f31-ab4d-cde503d59f4e",
"metadata": {},
"outputs": [],
"source": [
"J = jacfwd(deterministic_leapfrog, argnums=0)(X, 1e-1, nn_params)  # d X_out / d X"
]
},
{
"cell_type": "code",
"execution_count": 162,
"id": "9a021325-dbc6-4dbf-ab3a-8248c82544be",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[1.00829751e+00, 9.24930722e-04, 1.00000000e-01,\n",
" 0.00000000e+00],\n",
" [9.55510605e-04, 1.00028251e+00, 0.00000000e+00,\n",
" 1.00000000e-01],\n",
" [1.48013949e-01, 1.72562925e-02, 1.00644961e+00,\n",
" 7.94508611e-04],\n",
" [1.70953111e-02, 5.11553162e-03, 7.47599173e-04,\n",
" 1.00022829e+00]], dtype=float64)"
]
},
"execution_count": 162,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# each leapfrog sub-step is a shear, so det(J) should be 1 up to round-off\n",
"J, jnp.linalg.det(J)"
]
},
{
"cell_type": "code",
"execution_count": 164,
"id": "2d711d35-b2ba-47ce-b52f-9d4e3f577021",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7f23aebfb970>"
]
},
"execution_count": 164,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQcAAAD8CAYAAAB6iWHJAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAANN0lEQVR4nO3dcchd9X3H8fdnWUolOrIRnWmSaWFh0hWm7iFVhOFcHRqE9A8p8Y8qMnio2GJh/lE2cOy//VWYWHSBSg2UdoKtC1264jqHCnU1SZNMzdyCKxgMi7OaGI0r0e/+uEd5ePw9Scw999z7JO8XXJ5z7vnlfn+X5Pnk3HPOPd9UFZK02K9NewKSZpPhIKnJcJDUZDhIajIcJDUZDpKafn2cP5zkt4C/By4HfgF8sareaIz7BfAW8B5wsqrmxqkrafLG3XP4OvCTqtoI/KRbX8ofV9WVBoO0PIwbDluAR7rlR4AvjPl6kmZExrlCMsmbVbV6wfobVfWbjXH/DbwBFPB3VbXtFK85D8wDrFq16g+vuOKKs57frPr57t3TnsLEXDTtCUzIW9OewIS8D1RVWttOGw5J/hm4tLHpL4FHzjAcPlVVrya5BHgC+GpVPXW6ic/NzdWuXbtON2zZWZXm38U54YZpT2BC/mXaE5iQd4H3lgiH0x6QrKrPL7Utyf8kWVtVh5OsBY4s8Rqvdj+PJPkBsAk4bThImp5xjznsAO7olu8A/mHxgCSrklz0wTLwp8DzY9aVNGHjhsPfADcm+S/gxm6dJJ9KsrMb89vAM0n2AT8D/rGq/mnMupImbKzrHKrqdeBPGs+/Cmzull8G/mCcOpKG5xWSkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU29hEOSm5K8lORgko90vcrI/d32/Umu7qOupMkZOxySrAC+CdwMfAa4LclnFg27GdjYPeaBB8etK2my+thz2AQcrKqXq+pXwPcYtclbaAuwvUaeBVZ3fS4kzag+wmEd8MqC9UPdcx93jKQZ0kc4tFppLe6xdyZjRgOT+SS7kux67bXXxp6cpLPTRzgcAjYsWF8PvHoWYwCoqm1VNVdVcxdffHEP05N0NvoIh+eAjUk+neQTwFZGbfIW2gHc3p21uAY4WlWHe6gtaULG6ngFUFUnk3wF+DGwAni4ql5I8uVu+0PATkYdsA4C7wB3jltX0mSNHQ4AVbWTUQAsfO6hBcsF3N1HLUnD8ApJSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVLTUL0yr09yNMne7nFfH3UlTc7YN5hd0CvzRkb9KZ5LsqOqXlw09OmqumXcepKG0cfdpz/slQmQ5INemYvD4WP7+e7drEqrWdby9nadmPYUJmZVLpj2FCbihmlPYEKePsW2oXplAlybZF+SHyX5/aVebGE7vGa/PEmD6GPP4Uz6YO4BLquq40k2A48DG1svVlXbgG0AKxLzQZqSQXplVtWxqjreLe8EViZZ00NtSRMySK/MJJcmo4MHSTZ1dV/vobakCRmqV+atwF1JTgIngK1dizxJMyqz/Du6IqlPTnsSE+DZiuXnXD5b8WZV85SgV0hKajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNfXVDu/hJEeSPL/E9iS5v2uXtz/J1X3UlTQ5fe05fBu46RTbb2bUp2IjMA882FNdSRPSSzhU1VPAL08xZAuwvUaeBVYnWdtHbUmTMdQxhzNtmWc7PGlG9NEO70ycScu80ZO2w5NmwlB7DqdtmSdptgwVDjuA27uzFtcAR6vq8EC1JZ2FXj5WJPkucD2wJskh4K
+AlfBhO7ydwGbgIPAOcGcfdSVNTi/hUFW3nWZ7AXf3UUvSMLxCUlKT4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKlpqHZ41yc5mmRv97ivj7qSJqevvhXfBh4Atp9izNNVdUtP9SRN2FDt8CQtM0N1vAK4Nsk+Rs1s7q2qF1qDkswzarbLhcAdw81vMJfkgmlPYWLernOzSdmqtJq2LX/vnmLbUOGwB7isqo4n2Qw8zqjj9kcsbId3ie3wpKkZ5GxFVR2rquPd8k5gZZI1Q9SWdHYGCYcklyaj/bIkm7q6rw9RW9LZGaod3q3AXUlOAieArV0XLEkzKrP8O3pJUl+c9iQm4NFpT2CCjszwv6dxnMsHJN+rar45r5CU1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIaho7HJJsSPJkkgNJXkhyT2NMktyf5GCS/UmuHreupMnq4wazJ4E/r6o9SS4Cdid5oqpeXDDmZkZ9KjYCnwMe7H5KmlFj7zlU1eGq2tMtvwUcANYtGrYF2F4jzwKrk6wdt7akyen1mEOSy4GrgH9btGkd8MqC9UN8NEA+eI35JLuS7DrR5+QkfSy9hUOSC4HHgK9V1bHFmxt/pHkP86raVlVzVTV37naUlGZfL+GQZCWjYPhOVX2/MeQQsGHB+npGDXUlzag+zlYE+BZwoKq+scSwHcDt3VmLa4CjVXV43NqSJqePsxXXAV8C/j3J3u65vwB+Bz5sh7cT2AwcBN4B7uyhrqQJGjscquoZ2scUFo4p4O5xa0kajldISmoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqclwkNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUN1Q7v+iRHk+ztHveNW1fSZA3VDg/g6aq6pYd6kgYwVDs8SctMH3sOHzpFOzyAa5PsY9TM5t6qemGJ15gH5mGUXI/2OcEZ8e60JzBBq3LKG5EvW2/XudmccW7uuiW39RYOp2mHtwe4rKqOJ9kMPM6o4/ZHVNU2YBvAyqTZMk/S5A3SDq+qjlXV8W55J7AyyZo+akuajEHa4SW5tBtHkk1d3dfHrS1pcoZqh3crcFeSk8AJYGvXBUvSjBqqHd4DwAPj1pI0HK+QlNRkOEhqMhwkNRkOkpoMB0lNhoOkJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGrq4wazn0zysyT7unZ4f90YkyT3JzmYZH+Sq8etK2my+rjB7P8BN3Q9KVYCzyT5UVU9u2DMzYz6VGwEPgc82P2UNKP6aIdXH/SkAFZ2j8V3lt4CbO/GPgusTrJ23NqSJqevpjYrutvSHwGeqKrF7fDWAa8sWD+E/TSlmdZLOFTVe1V1JbAe2JTks4uGtG5d3+xbkWQ+ya4ku97vY3KSzkqvZyuq6k3gX4GbFm06BGxYsL6eUUPd1mtsq6q5qprzVIo0PX2crbg4yepu+QLg88B/LBq2A7i9O2txDXC0qg6PW1vS5PRxtmIt8EiSFYzC5tGq+mGSL8OH7fB2ApuBg8A7wJ091JU0QX20w9sPXNV4/qEFywXcPW4tScPxY72kJsNBUpPhIKnJcJDUZDhIajIcJDUZDpKaDAdJTYaDpCbDQVKT4SCpyXCQ1GQ4SGoyHCQ1GQ6SmgwHSU2Gg6Qmw0FSk+EgqWmoXpnXJzmaZG/3uG/cupIma6hemQBPV9UtPdSTNIA+7j5dwOl6ZUpaZvrYc6DrWbEb+F3gm41emQDXJtnHqNPVvVX1whKvNQ/Md6vHX4OX+pjjGVgD/O9AtYbk++rBqF/TYIZ8b5cttSGj//j70XW++gHw1ap6fsHzvwG833302Az8bVVt7K1wD5
Lsqqq5ac+jb76v5WdW3tsgvTKr6lhVHe+WdwIrk6zps7akfg3SKzPJpUnSLW/q6r4+bm1JkzNUr8xbgbuSnAROAFurz88z/dg27QlMiO9r+ZmJ99brMQdJ5w6vkJTUZDhIajrvwyHJTUleSnIwydenPZ++JHk4yZEkz59+9PKRZEOSJ5Mc6C7Xv2fac+rDmXwNYfA5nc/HHLqDqP8J3AgcAp4DbquqF6c6sR4k+SNGV65ur6rPTns+fUmyFlhbVXuSXMTo4rsvLPe/s+5s3qqFX0MA7ml8DWEw5/uewybgYFW9XFW/Ar4HbJnynHpRVU8Bv5z2PPpWVYerak+3/BZwAFg33VmNr0Zm6msI53s4rANeWbB+iHPgH9r5IsnlwFVA63L9ZSfJiiR7gSPAE0t8DWEw53s4pPHc+fs5axlJciHwGPC1qjo27fn0oareq6orgfXApiRT/Th4vofDIWDDgvX1jL4YphnWfSZ/DPhOVX1/2vPp21JfQxja+R4OzwEbk3w6ySeArcCOKc9Jp9AduPsWcKCqvjHt+fTlTL6GMLTzOhyq6iTwFeDHjA5sPbrUV8mXmyTfBX4K/F6SQ0n+bNpz6sl1wJeAGxbcWWzztCfVg7XAk0n2M/pP64mq+uE0J3Ren8qUtLTzes9B0tIMB0lNhoOkJsNBUpPhIKnJcJDUZDhIavp/bOAKhKI3ezQAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(J, cmap='hot')"
]
},
{
"cell_type": "code",
"execution_count": 165,
"id": "635711e0-a455-4866-9d7a-6a0058f7a360",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(1., dtype=float64)"
]
},
"execution_count": 165,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jnp.linalg.det(J)"
]
},
{
"cell_type": "markdown",
"id": "c7e17eb0-41f5-4227-8db0-95e44838b5cc",
"metadata": {},
"source": [
"Neat: `det(J)` is 1, so the map is volume-preserving."
]
},
{
"cell_type": "markdown",
"id": "a0ff2eb0-eb45-4001-a7d9-9db9f430f895",
"metadata": {},
"source": [
"It should also be the case that, if we define the velocity-flip map\n",
"$$\n",
"s(x, v) = (x, -v)\n",
"$$\n",
"then\n",
"\n",
"$$\n",
"s \\circ \\Phi \\circ s = \\Phi^{-1}\n",
"$$\n",
"Let's check this numerically."
]
},
{
"cell_type": "code",
"execution_count": 166,
"id": "e1bb3f0d-9214-435d-a567-b26ec2cd8df5",
"metadata": {},
"outputs": [],
"source": [
"def s(X: Conf) -> Conf:\n",
" _dim = X.shape[0] // 2\n",
" return jnp.concatenate([X[:_dim], -X[_dim:]])"
]
},
{
"cell_type": "code",
"execution_count": 178,
"id": "e48679a6-94cf-420f-8dfb-36bf28c784df",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([-0.96499554, 0.59350874, -1.38793319, 0.55717537], dtype=float64)"
]
},
"execution_count": 178,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X"
]
},
{
"cell_type": "markdown",
"id": "4b41fbcd-3caa-4640-8a76-fccd25dcffba",
"metadata": {},
"source": [
"Specifically, we check that\n",
"$$\n",
"\\Phi^{-1} \\circ \\Phi(X) = s \\circ \\Phi \\circ s \\circ \\Phi(X)\n",
"$$\n",
"recovers $X$ (compare the output below with `X` above)."
]
},
{
"cell_type": "code",
"execution_count": 180,
"id": "b1abd634-abdd-47ad-8e74-830572d87fd8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([-0.96499554, 0.59350874, -1.38793319, 0.55717537], dtype=float64)"
]
},
"execution_count": 180,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"s(\n",
" deterministic_leapfrog(s(deterministic_leapfrog(X, 1e-1, nn_params)), 1e-1, nn_params)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "1dd76efe-ad4c-400d-b97e-8b13447d9b94",
"metadata": {},
"source": [
"That does it: the output matches `X` above, so we can now define the inverse map as $\\Phi^{-1} = s \\circ \\Phi \\circ s$!"
]
},
{
"cell_type": "markdown",
"id": "d0158b56-8de7-4c7a-b856-0e62917bed0a",
"metadata": {},
"source": [
"## Annealed Flow Transport\n",
"Let's now make a modified AIS sampler that runs annealed flow transport.\n",
"But first, let's build _just_ the trainable block, where we apply a deterministic map between two distributions; the stochastic kernel is already built and just needs to be appended at the end."
]
},
{
"cell_type": "code",
"execution_count": 491,
"id": "f818a426-af95-441e-a8aa-2ea97a3602bb",
"metadata": {},
"outputs": [],
"source": [
"def make_leapfrog_NF(_t0_sampler: Callable[Seed, Conf], \n",
" _t0_energy_fn: Callable[[Conf, Params], Conf],\n",
" _t1_energy_fn: Callable[[Conf, Params], Conf],\n",
" dt : float,\n",
" features: Optional[Sequence]=[2,2],\n",
" seed: Optional[Seed]=random.PRNGKey(13)\n",
" ) -> Tuple[nnParams, Callable[[Conf, nnParams, Seed, float], Conf]]:\n",
" \n",
" params, leapfrog_nf = make_mod_leapfrog(seed, features)\n",
" _dim = features[0]\n",
" \n",
" def Importance_Sampler(seed: Seed, \n",
" nn_params_as_tuple: Tuple[nnParams], \n",
" t0_energy_params: Params, \n",
" t1_energy_params: Params) -> Tuple[Conf, Conf, float]:\n",
" X_in = _t0_sampler(seed)\n",
" X_out = leapfrog_nf(X_in, dt, nn_params_as_tuple[0])\n",
" \n",
" t0_u, t0_k = _t0_energy_fn(X_in[:_dim], t0_energy_params), kinetic_energy(1., X_in[_dim:])\n",
" t1_u, t1_k = _t1_energy_fn(X_out[:_dim], t1_energy_params), kinetic_energy(1., X_out[_dim:])\n",
" \n",
" work = t1_u + t1_k - t0_u - t0_k\n",
" \n",
" return X_in, X_out, work\n",
" \n",
" return params, Importance_Sampler "
]
},
{
"cell_type": "code",
"execution_count": 492,
"id": "c560fb86-224d-404a-a4b9-647093394d77",
"metadata": {},
"outputs": [],
"source": [
"is_nn_params, isampler = make_leapfrog_NF(_t0_sampler, _energy_fn, _energy_fn, 1.)"
]
},
{
"cell_type": "code",
"execution_count": 493,
"id": "ed33f9d6-0d8f-4e50-8ffa-7df2c01114e8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(DeviceArray([ 0.96907176, 0.51501533, -1.43994373, -1.6511834 ], dtype=float64),\n",
" DeviceArray([-0.1030787 , -1.23066942, -0.41721773, -2.18985484], dtype=float64),\n",
" DeviceArray(2.57900905, dtype=float64))"
]
},
"execution_count": 493,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"isampler(random.PRNGKey(34155), (is_nn_params,), jnp.zeros(2), jnp.ones(2))"
]
},
{
"cell_type": "code",
"execution_count": 494,
"id": "645e3901-146d-4c04-8024-b07f3ef6577e",
"metadata": {},
"outputs": [],
"source": [
"visampler = vmap(partial(isampler, t0_energy_params = jnp.zeros(2), t1_energy_params = jnp.ones(2)), in_axes = (0,None))"
]
},
{
"cell_type": "code",
"execution_count": 495,
"id": "5f3c9e32-f350-4899-a74c-58a5da041eed",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(DeviceArray([[-0.13315785, 1.01386513, -0.6212633 , 0.60865124],\n",
" [ 0.14958146, 0.43117586, 0.05241661, 0.05378837],\n",
" [-0.20707079, -0.55351003, 0.70936913, -0.66941612],\n",
" [-0.86510991, -1.36719591, -1.32867807, -0.09460279],\n",
" [ 0.14172441, -0.01347149, -0.84180898, -0.1111099 ],\n",
" [ 0.53151389, -0.90371657, 0.46633612, 0.28262607],\n",
" [ 0.574772 , 1.10342848, 0.16865889, 0.83267297],\n",
" [-0.36430806, -0.83109545, -1.67787892, -0.75173715],\n",
" [ 0.55795223, 0.4525778 , 0.35620414, -1.81124353],\n",
" [-0.58820621, 1.52723167, -0.21199303, -1.32970617]], dtype=float64),\n",
" DeviceArray([[-1.41461745, 2.05706838, -1.87873927, 1.59475793],\n",
" [-0.16087358, 0.65367767, -0.9466206 , 0.5700123 ],\n",
" [ 0.93157583, -1.43000784, 1.75411696, -1.40996385],\n",
" [-1.62229365, -1.80458228, -0.2986438 , -0.74485972],\n",
" [-0.52942469, -0.18617663, -0.99167394, -0.07049572],\n",
" [ 1.65730698, -1.08574131, 1.73617314, -0.7211537 ],\n",
" [ 0.17391125, 2.25957034, -1.01357087, 1.69156383],\n",
" [-1.51430875, -1.86071866, -0.61031067, -1.38260964],\n",
" [ 0.93583491, -1.32989618, 0.99719163, -2.31176757],\n",
" [-1.42109755, 0.7245177 , -1.46663659, -0.28907386]], dtype=float64),\n",
" DeviceArray([5.60930358, 1.23732164, 4.83698509, 5.49706745, 1.99664333,\n",
" 3.4600858 , 1.9439734 , 6.29286094, 3.92378597, 1.84035386], dtype=float64))"
]
},
"execution_count": 495,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"visampler(random.split(_s, num=10), (is_nn_params,))"
]
},
{
"cell_type": "code",
"execution_count": 496,
"id": "6752d3df-b990-442c-bc57-1937b7306aaa",
"metadata": {},
"outputs": [],
"source": [
"def loss(in_nn_params: nnParams, run_keys: Seed) -> float:\n",
" _, _, works = visampler(run_keys, (in_nn_params,))\n",
" return jnp.mean(works)"
]
},
{
"cell_type": "code",
"execution_count": 497,
"id": "beeb5897-29f7-4876-b3e6-0dc7de15af4c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(3.66383811, dtype=float64)"
]
},
"execution_count": 497,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss(is_nn_params, random.split(_s, num=10))"
]
},
{
"cell_type": "code",
"execution_count": 498,
"id": "5e1f9bc9-4b95-4635-bfd0-c7526fdaa3c3",
"metadata": {},
"outputs": [],
"source": [
"from jax.experimental import optimizers\n",
"from jax import value_and_grad"
]
},
{
"cell_type": "code",
"execution_count": 499,
"id": "950e1dc5-49bf-4343-b85d-431e39eaaa3e",
"metadata": {},
"outputs": [],
"source": [
"opt_init, opt_update, get_params = optimizers.adam(step_size=1e-1)"
]
},
{
"cell_type": "code",
"execution_count": 500,
"id": "3d6c53e7-d9f0-4dbc-9db4-b9f6af676dc6",
"metadata": {},
"outputs": [],
"source": [
"def step(i, opt_state, seeds):\n",
" params = get_params(opt_state)\n",
" _val, g = value_and_grad(loss)(params, seeds)\n",
" return _val, opt_update(i, g, opt_state)\n",
"\n",
"#step = jit(step)"
]
},
{
"cell_type": "code",
"execution_count": 501,
"id": "87d18b3b-9b61-4010-b0b3-9c83eb5b667b",
"metadata": {},
"outputs": [],
"source": [
"iters = int(1e2)\n",
"opt_state = opt_init(is_nn_params)\n",
"parent_seed = random.PRNGKey(781)\n",
"mean_works = []"
]
},
{
"cell_type": "code",
"execution_count": 502,
"id": "87b0ac04-5f52-40c3-8f4d-1f38617f1c45",
"metadata": {},
"outputs": [],
"source": [
"import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 503,
"id": "fee0c88a-4ebe-4e60-a6ee-fae60ab1a58a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 100/100 [00:10<00:00, 9.45it/s]\n"
]
}
],
"source": [
"for i in tqdm.trange(iters):\n",
" #parent_seed = random.PRNGKey(781)\n",
" parent_seed, child_seed = random.split(parent_seed)\n",
" run_seeds = random.split(child_seed, num=100)\n",
" _val, opt_state = step(i, opt_state, run_seeds)\n",
" \n",
" mean_works.append(_val)"
]
},
{
"cell_type": "code",
"execution_count": 506,
"id": "d7866f4d-a348-4bbd-9eb4-4655a053f0c5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'mean work (fwd)')"
]
},
"execution_count": 506,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(mean_works)\n",
"plt.xlabel(f\"train epoch\")\n",
"plt.ylabel(f\"mean work (fwd)\")"
]
},
{
"cell_type": "markdown",
"id": "516e069f-4c0e-40e8-a8ca-09c0d59fbd14",
"metadata": {},
"source": [
"Above, I implemented the NF part of [AFT](https://arxiv.org/pdf/2102.07501.pdf), considering the case where there is a _single_ NF between the prior and posterior (both 2D Normal distributions, centered at (0,0) and (1,1), respectively). <br>\n",
"I demonstrated that we can train the NF (by minimizing the mean forward work) when it is defined by a modified leapfrog integrator, as described in Eq. 14 of [Nonreversible MCMC from conditional invertible transforms](https://arxiv.org/pdf/2012.15550.pdf) (thanks to [Josh Fass](https://github.com/maxentile) for pointing out the neat volume-preserving leapfrog integrator). <br>"
]
},
{
"cell_type": "markdown",
"id": "c6a10281-0c11-4d76-af4a-96b978c71db2",
"metadata": {},
"source": [
"Next, let's try to implement this in a multi-step setting (i.e. N>1 bridging distributions) and see whether we can train it sequentially. Hopefully we can observe an improvement over vanilla AIS."
]
},
{
"cell_type": "markdown",
"id": "83246d4a-3bdf-48be-a036-220d8c9aeef3",
"metadata": {},
"source": [
"Afterwards, maybe we can apply this to a molecular system?!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "83a34043-7208-4399-9505-800236d0fa21",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment