Skip to content

Instantly share code, notes, and snippets.

@smsharma
Created July 8, 2023 05:57
Show Gist options
  • Save smsharma/9b781aa9310342993bc102ddf9af9852 to your computer and use it in GitHub Desktop.
Save smsharma/9b781aa9310342993bc102ddf9af9852 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 370,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import jax\n",
"from jax import random\n",
"\n",
"from flax import linen as nn\n",
"from einops import rearrange\n",
"import optax\n",
"import tensorflow as tf\n",
"import numpy as np\n",
"from tqdm import tqdm, trange\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import tensorflow_probability.substrates.jax as tfp"
]
},
{
"cell_type": "code",
"execution_count": 371,
"metadata": {},
"outputs": [],
"source": [
"from jetnet.datasets import JetNet\n",
"\n",
"# Load JetNet jet type 'q' from ../data/, requesting 30 particles per jet.\n",
"particle_data, jet_data = JetNet.getData(\n",
"    jet_type='q',\n",
"    data_dir=\"../data/\",\n",
"    num_particles=30,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 372,
"metadata": {},
"outputs": [],
"source": [
"n_feat = 3\n",
"n_particles = 15\n",
"\n",
"# Keep only the first n_particles entries of axis 1 and the first n_feat of axis 2.\n",
"x = particle_data[:, :n_particles, :n_feat]\n",
"\n",
"# One mean / std per feature, pooled over the two leading axes.\n",
"x_mean = x.mean(axis=(0, 1))\n",
"x_std = x.std(axis=(0, 1))\n",
"\n",
"# Normalize\n",
"x = (x - x_mean) / x_std\n",
"\n",
"# (batch, particle, feature) -> (batch, particle * feature)\n",
"x_flattened = rearrange(x, 'b p c -> b (p c)')"
]
},
{
"cell_type": "code",
"execution_count": 373,
"metadata": {},
"outputs": [],
"source": [
"class MLP(nn.Module):\n",
"    \"\"\"Plain fully connected network; used as both the VAE encoder and decoder.\n",
"\n",
"    Applies `n_layers` blocks of Dense(`hidden_dim`) followed by GELU, then a\n",
"    final Dense(`out_dim`) projection with no activation.\n",
"    \"\"\"\n",
"\n",
"    hidden_dim: int = 128  # width of every hidden layer\n",
"    out_dim: int = 2  # size of the final linear projection\n",
"    n_layers: int = 4  # number of hidden Dense + GELU blocks\n",
"\n",
"    @nn.compact\n",
"    def __call__(self, x):\n",
"        # Hidden stack: Dense -> GELU, repeated n_layers times.\n",
"        for _ in range(self.n_layers):\n",
"            x = nn.Dense(features=self.hidden_dim)(x)\n",
"            x = nn.gelu(x)\n",
"        # Linear read-out; the raw outputs are interpreted by the caller.\n",
"        x = nn.Dense(features=self.out_dim)(x)\n",
"        return x\n"
]
},
{
"cell_type": "code",
"execution_count": 374,
"metadata": {},
"outputs": [],
"source": [
"import functools\n",
"\n",
"import ott\n",
"from ott.geometry import costs, geometry, pointcloud\n",
"from ott.problems.linear import linear_problem\n",
"\n",
"from ott.solvers.linear import acceleration, sinkhorn\n",
"from ott.tools.sinkhorn_divergence import sinkhorn_divergence"
]
},
{
"cell_type": "code",
"execution_count": 447,
"metadata": {},
"outputs": [],
"source": [
"class VAE(nn.Module):\n",
"    \"\"\"Variational auto-encoder whose encoder and decoder both receive beta as an extra input.\"\"\"\n",
"\n",
"    num_latents: int = 4  # dimensionality of the latent code z\n",
"    num_out: int = 2  # dimensionality of the reconstruction\n",
"\n",
"    def setup(self):\n",
"        # The encoder emits two numbers (mean, log-variance) per latent dimension.\n",
"        self.encoder = MLP(out_dim=2 * self.num_latents)\n",
"        self.decoder = MLP(out_dim=self.num_out)\n",
"\n",
"    def __call__(self, x, beta, z_rng):\n",
"        \"\"\"Encode `x`, draw one latent sample with `z_rng`, and decode it.\n",
"\n",
"        `beta` is appended along the last axis to both the encoder input and\n",
"        the sampled latent. Returns `(recon_x, mu, logvar)`.\n",
"        \"\"\"\n",
"        encoder_in = jnp.concatenate([x, beta], axis=-1)\n",
"\n",
"        # (batch, 2 * num_latents) -> (batch, num_latents, 2); last axis holds (mu, logvar).\n",
"        stats = rearrange(self.encoder(encoder_in), 'b (n c) -> b n c', c=2)\n",
"        mu = stats[:, :, 0]\n",
"        logvar = stats[:, :, 1]\n",
"\n",
"        # Draw z ~ Normal(mu, sigma) with sigma = exp(logvar / 2).\n",
"        posterior = tfp.distributions.Normal(loc=mu, scale=jnp.exp(0.5 * logvar))\n",
"        z = posterior.sample(seed=z_rng)\n",
"\n",
"        decoder_in = jnp.concatenate([z, beta], axis=-1)\n",
"        recon_x = self.decoder(decoder_in)\n",
"        return recon_x, mu, logvar\n",
"\n",
"@jax.vmap\n",
"def kl_divergence(mu, logvar):\n",
"    \"\"\"Element-wise KL(q || p) with q = Normal(mu, exp(logvar / 2)) and p = Normal(0, 1).\n",
"\n",
"    Mapped over the leading axis of both arguments by `jax.vmap`.\n",
"    \"\"\"\n",
"    sigma = jnp.exp(0.5 * logvar)\n",
"    posterior = tfp.distributions.Normal(loc=mu, scale=sigma)  # Variational latent distrib.\n",
"    standard_normal = tfp.distributions.Normal(loc=0., scale=1.)  # Prior\n",
"    return tfp.distributions.kl_divergence(posterior, standard_normal)\n",
"\n",
"@jax.vmap\n",
"def reco_gap(pred, true, beta=0.01):\n",
"    \"\"\"Negative log-density of `true` under Normal(loc=pred, scale=beta).\n",
"\n",
"    Mapped over the leading axis of every argument that is passed in; `beta`\n",
"    acts as the standard deviation of the Gaussian likelihood.\n",
"    \"\"\"\n",
"    likelihood = tfp.distributions.Normal(loc=pred, scale=beta)\n",
"    return -likelihood.log_prob(true)\n",
"\n",
"# @jax.jit\n",
"# def get_sinkhorn(x1, x2):\n",
"# geom = pointcloud.PointCloud(x1, x2)\n",
"# ot_prob = linear_problem.LinearProblem(geom)\n",
"# solver = sinkhorn.Sinkhorn()\n",
"# ot = solver(ot_prob)\n",
"# return ot.reg_ot_cost\n",
"\n",
"# @jax.vmap\n",
"# def reco_gap(pred, true, beta=1.):\n",
"# \"\"\" Sinkhorn reco\n",
"# \"\"\"\n",
"\n",
"# pred_unflatten = rearrange(pred, '(n c) -> n c', c=n_feat)\n",
"# true_unflatten = rearrange(true, '(n c) -> n c', c=n_feat)\n",
"\n",
"# sinkhorn_div = 1 / beta * get_sinkhorn(pred_unflatten, true_unflatten)\n",
"# return sinkhorn_div"
]
},
{
"cell_type": "code",
"execution_count": 448,
"metadata": {},
"outputs": [],
"source": [
"num_latents = 64\n",
"num_out = x_flattened.shape[-1]  # one output per flattened input column\n",
"\n",
"vae = VAE(num_latents=num_latents, num_out=num_out)\n",
"\n",
"# `key` seeds the parameter initialisers; `z_key` seeds the latent draw made\n",
"# while the model is traced on a dummy batch of 16 rows with beta = 1.\n",
"key = jax.random.PRNGKey(42)\n",
"key, z_key = random.split(key)\n",
"params = vae.init(key, x_flattened[:16], jnp.ones((16, 1)), z_key)\n"
]
},
{
"cell_type": "code",
"execution_count": 449,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4602346500.0\n"
]
}
],
"source": [
"@jax.jit\n",
"def loss_fn(params, x_batch, log_beta_batch, z_rng):\n",
"    \"\"\"Batch-averaged objective: reconstruction term plus KL term.\n",
"\n",
"    `log_beta_batch` holds log10(beta) per row; beta is fed to the model and\n",
"    also used as the scale of the Gaussian reconstruction likelihood.\n",
"    \"\"\"\n",
"    beta_batch = 10. ** log_beta_batch\n",
"\n",
"    recon_x, mu, logvar = vae.apply(params, x_batch, beta_batch, z_rng)\n",
"\n",
"    # Both terms are averaged (not summed) over their last axis: one value per row.\n",
"    per_row_reco = reco_gap(recon_x, x_batch, beta_batch).mean(-1)\n",
"    per_row_kl = kl_divergence(mu, logvar).mean(-1)\n",
"\n",
"    return (per_row_reco + per_row_kl).mean()\n",
"\n",
"# Smoke test on the first 128 rows with log10(beta) = -5.\n",
"print(loss_fn(params, x_flattened[:128], -5. * jnp.ones((128, 1)), key))"
]
},
{
"cell_type": "code",
"execution_count": 450,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 5000/5000 [01:13<00:00, 67.73it/s, val=2330.8193] \n"
]
}
],
"source": [
"n_steps = 5000\n",
"n_batch = 128\n",
"n_warmup = int(0.1 * n_steps)\n",
"\n",
"# optax's `decay_steps` is the TOTAL schedule length (warm-up included), so it\n",
"# must equal n_steps; a shorter value leaves the tail of training at end_value.\n",
"lr_schedule = optax.warmup_cosine_decay_schedule(\n",
"    init_value=0.0,\n",
"    peak_value=1e-3,\n",
"    warmup_steps=n_warmup,\n",
"    decay_steps=n_steps,\n",
"    end_value=0.0,\n",
")\n",
"\n",
"opt = optax.adam(learning_rate=lr_schedule)\n",
"opt_state = opt.init(params)\n",
"\n",
"# Build the value-and-gradient function once instead of on every iteration.\n",
"loss_and_grad = jax.value_and_grad(loss_fn)\n",
"\n",
"with trange(n_steps) as steps:\n",
"    for step in steps:\n",
"\n",
"        # One independent key per random draw; `key` itself is only ever split,\n",
"        # never consumed, so no randomness is shared between draws or steps.\n",
"        key, idx_key, beta_key, latent_key = jax.random.split(key, 4)\n",
"\n",
"        # Draw a random batch from x (indices sampled with replacement)\n",
"        idx = jax.random.choice(idx_key, x_flattened.shape[0], shape=(n_batch,))\n",
"        x_batch = x_flattened[idx]\n",
"\n",
"        # Draw random batch of log10(beta), uniform in [-3, 1)\n",
"        log_beta_batch = jax.random.uniform(beta_key, shape=(n_batch, 1), minval=-3., maxval=1.)\n",
"\n",
"        # Get loss and update\n",
"        loss, grads = loss_and_grad(params, x_batch, log_beta_batch, latent_key)\n",
"        updates, opt_state = opt.update(grads, opt_state, params)\n",
"        params = optax.apply_updates(params, updates)\n",
"\n",
"        steps.set_postfix(val=loss)"
]
},
{
"cell_type": "code",
"execution_count": 531,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, '$D$')"
]
},
"execution_count": 531,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"@jax.jit\n",
"def eval(params, z):\n",
"    \"\"\"Run only the decoder of the global `vae` on `z` using `params`.\n",
"\n",
"    `z` is handed to the decoder unchanged, so it must already carry the\n",
"    trailing beta column that `VAE.__call__` appends before decoding.\n",
"    NOTE: this name shadows the builtin `eval` within the notebook.\n",
"    \"\"\"\n",
"    decode = nn.apply(lambda module: module.decoder(z), vae)\n",
"    return decode(params)\n",
"\n",
"@jax.jit\n",
"def encode_decode(params, x, beta, key):\n",
"    \"\"\"Encode `x` at the given `beta`, sample one latent with `key`, and decode it.\"\"\"\n",
"    encoder_in = jnp.concatenate([x, beta], axis=-1)\n",
"\n",
"    def _roundtrip(module):\n",
"        # After the reshape the last axis holds (mu, logvar) per latent dimension.\n",
"        stats = rearrange(module.encoder(encoder_in), 'b (n c) -> b n c', c=2)\n",
"        mu, logvar = stats[:, :, 0], stats[:, :, 1]\n",
"        posterior = tfp.distributions.Normal(loc=mu, scale=jnp.exp(0.5 * logvar))\n",
"        latent = posterior.sample(seed=key)\n",
"        return module.decoder(jnp.concatenate([latent, beta], axis=-1))\n",
"\n",
"    return nn.apply(_roundtrip, vae)(params)\n",
"\n",
"def get_RD(params, x, beta, key):\n",
"    \"\"\"Return (rate, distortion) of the global `vae` on batch `x` at the given `beta`.\n",
"\n",
"    Rate: KL between the encoder's Normal and a unit Normal, summed over the\n",
"    latent axis and averaged over the batch. Distortion: `reco_gap` of one\n",
"    sampled reconstruction, evaluated with scale 1, summed over the output\n",
"    axis and averaged over the batch.\n",
"    \"\"\"\n",
"    encoder_in = jnp.concatenate([x, beta], axis=-1)\n",
"\n",
"    def _encode_sample_decode(module):\n",
"        stats = rearrange(module.encoder(encoder_in), 'b (n c) -> b n c', c=2)\n",
"        mu, logvar = stats[:, :, 0], stats[:, :, 1]\n",
"        posterior = tfp.distributions.Normal(loc=mu, scale=jnp.exp(0.5 * logvar))\n",
"        prior = tfp.distributions.Normal(loc=0., scale=1.)\n",
"        rate = tfp.distributions.kl_divergence(posterior, prior)\n",
"        latent = posterior.sample(seed=key)\n",
"        recon = module.decoder(jnp.concatenate([latent, beta], axis=-1))\n",
"        return recon, rate\n",
"\n",
"    reco, R = nn.apply(_encode_sample_decode, vae)(params)\n",
"\n",
"    # Compare against the input with its appended last column dropped, always at unit scale.\n",
"    target = encoder_in[..., :-1]\n",
"    unit_scale = jnp.ones((encoder_in.shape[0], 1))\n",
"    D = reco_gap(reco, target, unit_scale)\n",
"    return R.sum(-1).mean(), D.sum(-1).mean()\n",
"\n",
"ii = 20\n",
"n_agg = 16\n",
"\n",
"log_beta_ary = jnp.linspace(-3., 1., 100)\n",
"\n",
"R_list = []\n",
"D_list = []\n",
"\n",
"# Sweep log10(beta) over [-3, 1] on rows ii .. ii + n_agg - 1; the same `key`\n",
"# is passed at every beta.\n",
"for log_beta in log_beta_ary:\n",
"    beta = 10 ** log_beta * jnp.ones((n_agg, 1))\n",
"    R, D = get_RD(params, x_flattened[ii:ii + n_agg], beta, key)\n",
"    R_list.append(R)\n",
"    D_list.append(D)\n",
"\n",
"# Rate on the horizontal axis, distortion on the vertical axis.\n",
"plt.plot(R_list, D_list, 'o-')\n",
"plt.xlabel(r'$R$')\n",
"plt.ylabel(r'$D$')"
]
},
{
"cell_type": "code",
"execution_count": 532,
"metadata": {},
"outputs": [],
"source": [
"# Encode and decode the single row `ii` at beta = 1e-3 and keep the one reconstruction.\n",
"ii = 4\n",
"beta = 1.e-3 * jnp.ones((1, 1))\n",
"x_samples = encode_decode(params, x_flattened[ii:ii + 1], beta, key)[0]"
]
},
{
"cell_type": "code",
"execution_count": 533,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.collections.PathCollection at 0x349f36400>"
]
},
"execution_count": 533,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Undo the flattening: one row per particle, n_feat columns.\n",
"x_samples_unflatten = rearrange(x_samples, '(n c) -> n c', c=n_feat)\n",
"\n",
"# Columns 0 and 1 are the scatter coordinates; column 2 sets the marker area via exp().\n",
"recon_sizes = 100 * np.exp(x_samples_unflatten[:, 2])\n",
"true_sizes = 100 * np.exp(x[ii, :, 2])\n",
"\n",
"# Reconstruction first, then the original row `ii` on the same axes.\n",
"plt.scatter(x_samples_unflatten[:, 0], x_samples_unflatten[:, 1], s=recon_sizes)\n",
"plt.scatter(x[ii, :, 0], x[ii, :, 1], s=true_sizes)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment