@kvfrans
Created April 24, 2024 18:22
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"TPU_CHIPS_PER_PROCESS_BOUNDS\"] = \"2,2,1\"\n",
"os.environ[\"TPU_PROCESS_BOUNDS\"] = \"1,1,1\"\n",
"os.environ[\"TPU_VISIBLE_DEVICES\"] = \"0,1,2,3\"\n",
"\n",
"from typing import Any, Callable, Optional, Tuple, Type\n",
"\n",
"import flax.linen as nn\n",
"import jax.numpy as jnp\n",
"import os\n",
"from absl import app, flags\n",
"from functools import partial\n",
"import numpy as np\n",
"import tqdm\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import flax\n",
"import optax\n",
"import wandb\n",
"from ml_collections import config_flags\n",
"from flax.training import checkpoints\n",
"import ml_collections\n",
"import matplotlib.pyplot as plt\n",
"import functools"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"FLAGS = {\n",
" 'seed': 0,\n",
" 'batch_size': 256,\n",
" 'max_steps': 200_000,\n",
"}\n",
"model_config = {\n",
" 'lr': 0.001,\n",
" 'diffusion_timesteps': 500,\n",
" 'process': 'diffusion' # or 'flow\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nonpytree_field = functools.partial(flax.struct.field, pytree_node=False)\n",
"class TrainState(flax.struct.PyTreeNode):\n",
" step: int\n",
" apply_fn: Callable[..., Any] = nonpytree_field()\n",
" model_def: Any = nonpytree_field()\n",
" params: Any\n",
" tx: Optional[optax.GradientTransformation] = nonpytree_field()\n",
" opt_state: Optional[optax.OptState] = None\n",
"\n",
" @classmethod\n",
" def create(\n",
" cls,\n",
" model_def: nn.Module,\n",
" params: Any,\n",
" tx: Optional[optax.GradientTransformation] = None,\n",
" **kwargs,\n",
" ) -> \"TrainState\":\n",
" if tx is not None:\n",
" opt_state = tx.init(params)\n",
" else:\n",
" opt_state = None\n",
"\n",
" return cls(\n",
" step=1,\n",
" apply_fn=model_def.apply,\n",
" model_def=model_def,\n",
" params=params,\n",
" tx=tx,\n",
" opt_state=opt_state,\n",
" **kwargs,\n",
" )\n",
"\n",
" def __call__(\n",
" self,\n",
" *args,\n",
" params=None,\n",
" extra_variables: dict = None,\n",
" method: Any = None,\n",
" **kwargs,\n",
" ):\n",
" \"\"\"\n",
" Internally calls model_def.apply_fn with the following logic:\n",
"\n",
" Arguments:\n",
" params: If not None, use these params instead of the ones stored in the model.\n",
" extra_variables: Additional variables to pass into apply_fn\n",
" method: If None, use the `__call__` method of the model_def. If a string, uses\n",
" the method of the model_def with that name (e.g. 'encode' -> model_def.encode).\n",
" If a function, uses that function.\n",
"\n",
" \"\"\"\n",
" if params is None:\n",
" params = self.params\n",
"\n",
" variables = {\"params\": params}\n",
"\n",
" if extra_variables is not None:\n",
" variables = {**variables, **extra_variables}\n",
"\n",
" if isinstance(method, str):\n",
" method = getattr(self.model_def, method)\n",
"\n",
" return self.apply_fn(variables, *args, method=method, **kwargs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _extract_into_tensor(arr, timesteps, broadcast_shape):\n",
" res = arr[timesteps]\n",
" while len(res.shape) < len(broadcast_shape):\n",
" res = res[..., None]\n",
" return res\n",
"\n",
"class SmallDiffusionModel(nn.Module):\n",
" @nn.compact\n",
" def __call__(self, x, t):\n",
" x = jnp.concatenate([x, t[:, None]], axis=-1)\n",
" for _ in range(4):\n",
" x = nn.Dense(features=64)(x)\n",
" x = nn.silu(x)\n",
" x = nn.Dense(features=2)(x)\n",
" return x\n",
" \n",
"class DiffusionTrainer(flax.struct.PyTreeNode):\n",
" rng: Any\n",
" model: TrainState\n",
" config: dict = flax.struct.field(pytree_node=False)\n",
" scheduler: Any = flax.struct.field(pytree_node=False)\n",
"\n",
" @partial(jax.pmap, axis_name='data')\n",
" def update(self, images, eps=None, pmap_axis='data'):\n",
" new_rng, time_key, noise_key = jax.random.split(self.rng, 3)\n",
"\n",
" def loss_fn(params):\n",
" if self.config['process'] == 'diffusion':\n",
" random_t = jax.random.randint(time_key, (images.shape[0],), 0, self.config['diffusion_timesteps'])\n",
" eps = jax.random.normal(noise_key, images.shape)\n",
" x_t = self.scheduler.q_sample(images, random_t, eps)\n",
" \n",
" eps_prime = self.model(x_t, random_t, params=params)\n",
" l2_loss_eps = jnp.mean((eps_prime - eps) ** 2)\n",
" \n",
" loss = l2_loss_eps\n",
" return loss, {\n",
" 'l2_loss': l2_loss_eps,\n",
" 'eps_abs_mean': jnp.abs(eps).mean(),\n",
" 'eps_pred_abs_mean': jnp.abs(eps_prime).mean(),\n",
" }\n",
" elif self.config['process'] == 'flow':\n",
" random_t = jax.random.uniform(time_key, (images.shape[0],))\n",
" if eps is None:\n",
" eps = jax.random.normal(noise_key, images.shape)\n",
" x_t = self.scheduler.q_sample(images, random_t, eps)\n",
" v_t = self.scheduler.v_sample(images, random_t, eps)\n",
" \n",
" v_prime = self.model(x_t, random_t, params=params)\n",
" l2_loss_eps = jnp.mean((v_t - v_prime) ** 2)\n",
" \n",
" loss = l2_loss_eps\n",
" return loss, {\n",
" 'l2_loss': l2_loss_eps,\n",
" 'v_abs_mean': jnp.abs(v_t).mean(),\n",
" 'v_pred_abs_mean': jnp.abs(v_prime).mean(),\n",
" }\n",
" \n",
" grads, info = jax.grad(loss_fn, has_aux=True)(self.model.params)\n",
" grads = jax.lax.pmean(grads, axis_name=pmap_axis)\n",
" info = jax.lax.pmean(info, axis_name=pmap_axis)\n",
"\n",
" updates, new_opt_state = self.model.tx.update(grads, self.model.opt_state, self.model.params)\n",
" new_params = optax.apply_updates(self.model.params, updates)\n",
" new_model = self.model.replace(step=self.model.step + 1, params=new_params, opt_state=new_opt_state)\n",
"\n",
" info['grad_norm'] = optax.global_norm(grads)\n",
" info['update_norm'] = optax.global_norm(updates)\n",
" info['param_norm'] = optax.global_norm(new_params)\n",
"\n",
" new_trainer = self.replace(rng=new_rng, model=new_model)\n",
" return new_trainer, info\n",
" \n",
" @partial(jax.jit)\n",
" def call_model(self, images, t):\n",
" return self.model(images, t)\n",
" \n",
" @partial(jax.pmap, axis_name='data')\n",
" def call_model_pmap(self, images, t):\n",
" return self.call_model(images, t)\n",
" \n",
" @partial(jax.pmap, axis_name='data')\n",
" def denoise_step(self, x, t, rng, pmap_axis='data'):\n",
" t = jnp.full((x.shape[0],), t)\n",
" eps_prime = self.call_model(x, t)\n",
" mean, variance, log_variance = self.scheduler.p_mean_variance(x, t, eps_prime)\n",
" x = mean + jnp.exp(0.5 * log_variance) * jax.random.normal(rng, x.shape)\n",
" return x\n",
" \n",
" @partial(jax.pmap, axis_name='data')\n",
" def denoise_step_ddim(self, x, t, rng, pmap_axis='data'):\n",
" t = jnp.full((x.shape[0],), t)\n",
" eps_prime = self.call_model(x, t)\n",
" x = self.scheduler.p_ddim(x, t, eps_prime)\n",
" return x\n",
"\n",
" @partial(jax.pmap, axis_name='data')\n",
" def denoise_step_flow(self, x, t, rng, time_interval, pmap_axis='data'):\n",
" t = jnp.full((x.shape[0],), t)\n",
" v_prime = self.call_model(x, t)\n",
" x = x + v_prime * time_interval\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class GaussianDiffusion:\n",
" def __init__(self, num_diffusion_timesteps):\n",
" scale = 1000 / num_diffusion_timesteps\n",
" beta_start = scale * 0.0001\n",
" beta_end = scale * 0.02\n",
" betas = jnp.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=jnp.float64)\n",
" assert len(betas.shape) == 1, \"betas must be 1-D\"\n",
" assert (betas > 0).all() and (betas <= 1).all()\n",
"\n",
" self.num_timesteps = int(betas.shape[0])\n",
"\n",
" alphas = 1.0 - betas\n",
" self.alphas_cumprod = jnp.cumprod(alphas, axis=0)\n",
" self.alphas_cumprod_prev = jnp.append(1.0, self.alphas_cumprod[:-1])\n",
" self.alphas_cumprod_next = jnp.append(self.alphas_cumprod[1:], 0.0)\n",
" assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)\n",
"\n",
" # calculations for diffusion q(x_t | x_{t-1}) and others\n",
" self.sqrt_alphas_cumprod = jnp.sqrt(self.alphas_cumprod)\n",
" self.sqrt_one_minus_alphas_cumprod = jnp.sqrt(1.0 - self.alphas_cumprod)\n",
" self.log_one_minus_alphas_cumprod = jnp.log(1.0 - self.alphas_cumprod)\n",
" self.sqrt_recip_alphas_cumprod = jnp.sqrt(1.0 / self.alphas_cumprod)\n",
" self.sqrt_recipm1_alphas_cumprod = jnp.sqrt(1.0 / self.alphas_cumprod - 1)\n",
"\n",
" # calculations for posterior q(x_{t-1} | x_t, x_0)\n",
" self.posterior_variance = (\n",
" betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)\n",
" )\n",
" # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain\n",
" self.posterior_log_variance_clipped = jnp.log(\n",
" jnp.append(self.posterior_variance[1], self.posterior_variance[1:])\n",
" ) if len(self.posterior_variance) > 1 else jnp.array([])\n",
"\n",
" # (beta_t * atm1.sqrt()) / (1.0 - at)\n",
" self.posterior_mean_coef1 = (\n",
" betas * jnp.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)\n",
" )\n",
" # ((1 - atm1) * (1 - beta_t).sqrt()) / (1.0 - at)\n",
" self.posterior_mean_coef2 = (\n",
" (1.0 - self.alphas_cumprod_prev) * jnp.sqrt(alphas) / (1.0 - self.alphas_cumprod)\n",
" )\n",
" \n",
" def q_sample(self, x_start, t, eps): # q(x_t | x_0)\n",
" \"\"\" q(x_t | x_0) \"\"\"\n",
" assert eps.shape == x_start.shape\n",
" return (\n",
" _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start\n",
" + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * eps\n",
" ) \n",
" \n",
" def _predict_xstart_from_eps(self, x_t, t, eps):\n",
" assert x_t.shape == eps.shape\n",
" return (\n",
" _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t\n",
" - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps\n",
" )\n",
"\n",
" def p_mean_variance(self, x_t, t, eps, clip=True):\n",
" \"\"\"p(x_{t-1} | x_t \"\"\"\n",
" # TODO: Handle learned variance.\n",
" \n",
" model_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)\n",
" model_log_variance = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)\n",
" pred_xstart = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=eps)\n",
" pred_xstart_clipped = jnp.clip(pred_xstart, -1, 1)\n",
" pred_xstart = jnp.where(clip, pred_xstart_clipped, pred_xstart)\n",
"\n",
" model_mean = _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * pred_xstart \\\n",
" + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t\n",
" return model_mean, model_variance, model_log_variance\n",
"\n",
" def p_ddim(self, x_t, t, eps):\n",
" at = _extract_into_tensor(self.alphas_cumprod, t, x_t.shape)\n",
" at_next = _extract_into_tensor(self.alphas_cumprod_prev, t, x_t.shape)\n",
" c2 = ((1 - at_next)).sqrt()\n",
" x0_t = (x_t - eps * (1 - at).sqrt()) / at.sqrt()\n",
" xt_next = at_next.sqrt() * x0_t + c2 * eps\n",
" return xt_next\n",
" \n",
"class RectifiedFlow():\n",
" def q_sample(self, x_start, t, eps): # q(x_t | x_0)\n",
" \"\"\" q(x_t | x_0). t=1 is fully noised, t=0 is the data. \"\"\"\n",
" assert eps.shape == x_start.shape\n",
" return (1-t[...,None]) * x_start + (t[...,None]) * eps\n",
" \n",
" def v_sample(self, x_start, t, eps):\n",
" return x_start - eps"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"##############################################\n",
"## Training Code.\n",
"##############################################\n",
"np.random.seed(FLAGS['seed'])\n",
"print(\"Using devices\", jax.local_devices())\n",
"device_count = len(jax.local_devices())\n",
"global_device_count = jax.device_count()\n",
"print(\"Device count\", device_count)\n",
"print(\"Global device count\", global_device_count)\n",
"local_batch_size = FLAGS['batch_size'] // (global_device_count // device_count)\n",
"\n",
"def get_data(n=100000):\n",
" from sklearn.datasets import make_swiss_roll\n",
" x, _ = make_swiss_roll(n, noise=0.5)\n",
" x = x[:, [0, 2]]\n",
" x /= 15\n",
" return x.astype('float32')\n",
"dataset = get_data(FLAGS['batch_size'] * 1000)\n",
"\n",
"rng = jax.random.PRNGKey(FLAGS['seed'])\n",
"rng, param_key = jax.random.split(rng, 2)\n",
"print(\"Total Memory on device:\", float(jax.local_devices()[0].memory_stats()['bytes_limit']) / 1024**3, \"GB\")\n",
"\n",
"model_def = SmallDiffusionModel()\n",
"\n",
"example_obs = dataset[:local_batch_size]\n",
"example_t = jnp.zeros((local_batch_size,))\n",
"model_rngs = {'params': param_key}\n",
"params = model_def.init(model_rngs, example_obs, example_t)['params']\n",
"tx = optax.adam(learning_rate=model_config['lr'])\n",
"model_ts = TrainState.create(model_def, params, tx=tx)\n",
"if model_config['process'] == 'diffusion':\n",
" scheduler = GaussianDiffusion(model_config['diffusion_timesteps'])\n",
"else:\n",
" scheduler = RectifiedFlow()\n",
"model = DiffusionTrainer(rng, model_ts, model_config, scheduler)\n",
"model = flax.jax_utils.replicate(model, devices=jax.local_devices())\n",
"\n",
"# Reflow\n",
"if model_config['process'] == 'flow':\n",
" params = model_def.init(model_rngs, example_obs, example_t)['params']\n",
" tx = optax.adam(learning_rate=model_config['lr'])\n",
" model_ts = TrainState.create(model_def, params, tx=tx)\n",
" model_reflowed = DiffusionTrainer(rng, model_ts, model_config, scheduler)\n",
" model_reflowed = flax.jax_utils.replicate(model_reflowed, devices=jax.local_devices())\n",
"\n",
"\n",
"###################################\n",
"# Train Loop\n",
"###################################\n",
"eps_rng = jax.random.PRNGKey(42)\n",
"\n",
"for i in tqdm.tqdm(range(1, FLAGS['max_steps'] + 1),\n",
" smoothing=0.1,\n",
" dynamic_ncols=True):\n",
" \n",
" random_data_ids = np.random.choice(len(dataset), FLAGS['batch_size'] * device_count, replace=False)\n",
" batch_images = dataset[random_data_ids]\n",
" batch_images = batch_images.reshape((device_count, -1, *batch_images.shape[1:])) # [devices, batch//devices, etc..]\n",
"\n",
" model, update_info = model.update(batch_images)\n",
"\n",
" if i > FLAGS['max_steps'] // 2 and model_config['process'] == 'flow':\n",
" # reflow\n",
" eps_rng, eps_key = jax.random.split(eps_rng)\n",
" eps = jax.random.normal(eps_key, batch_images.shape)\n",
" x = eps\n",
" # solve for x using the base model.\n",
" timesteps = model_config['diffusion_timesteps']\n",
" for ti in range(timesteps):\n",
" t = jnp.full((x.shape[0], x.shape[1]), timesteps-ti) # [devices, batch//devices]\n",
" t = t / timesteps\n",
" x = model.denoise_step(x, t, None)\n",
" model_reflowed, update_info = model_reflowed.update(x, eps)\n",
"\n",
" if i % 100_000 == 0:\n",
" key = jax.random.PRNGKey(42 + jax.process_index() + i)\n",
" x = jax.random.normal(key, batch_images.shape) # [devices, batch//devices, etc..]\n",
" key = flax.jax_utils.replicate(key, devices=jax.local_devices())\n",
" key += jnp.arange(len(jax.local_devices()), dtype=jnp.uint32)[:, None] * 1000\n",
" vmap_split = jax.vmap(jax.random.split, in_axes=(0))\n",
" all_x = []\n",
" all_x_short = []\n",
" for ti in range(model_config['diffusion_timesteps']):\n",
" rng, key = jnp.split(vmap_split(key), 2, axis=-1)\n",
" rng, key = rng[...,0], key[...,0]\n",
" t = jnp.full((x.shape[0], x.shape[1]), model_config['diffusion_timesteps']-ti) # [devices, batch//devices]\n",
" if model_config['process'] == 'flow':\n",
" t = t / model_config['diffusion_timesteps']\n",
" x = model.denoise_step(x, t, rng)\n",
" if ti % (model_config['diffusion_timesteps'] // 16) == 0 or ti == model_config['diffusion_timesteps']-1:\n",
" all_x_short.append(np.array(x))\n",
" all_x.append(np.array(x))\n",
" all_x_short = np.stack(all_x_short, axis=2) # [devices, batch//devices, timesteps, XY]\n",
" all_x_short = all_x_short.reshape((-1, *all_x_short.shape[2:]))\n",
" all_x = np.stack(all_x, axis=2) # [devices, batch//devices, timesteps, XY]\n",
" all_x_flat = all_x.reshape((-1, *all_x.shape[2:]))\n",
"\n",
" # plot comparison witah matplotlib. put each reconstruction side by side.\n",
" fig, axs = plt.subplots(4, 4, figsize=(30, 30))\n",
" axs_flat = axs.flatten()\n",
" for t in range(16):\n",
" axs_flat[t].axis(xmin=-2, xmax=2, ymin=-2, ymax=2)\n",
" axs_flat[t].plot(all_x_short[:, t, 0], all_x_short[:, t, 1], '.')\n",
" plt.show()\n",
" plt.close(fig)\n",
"\n",
" # plot a video.\n",
" imgs = []\n",
" for t in tqdm.tqdm(range(model_config['diffusion_timesteps'])):\n",
" if t % 10 != 0:\n",
" continue\n",
" fig, ax = plt.subplots(figsize=(4, 4))\n",
" ax.axis('off')\n",
" ax.axis(xmin=-2, xmax=2, ymin=-2, ymax=2)\n",
" ax.plot(all_x_flat[:, t, 0], all_x_flat[:, t, 1], '.', markersize=1)\n",
" fig.canvas.draw()\n",
" img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n",
" img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))\n",
" imgs.append(img)\n",
" plt.close(fig)\n",
"\n",
" # show gif\n",
" import imageio\n",
" imageio.mimsave(f'vid.gif', imgs, fps=100)\n",
" from IPython.display import Image, display\n",
" display(Image(filename=f'vid.gif'))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "project-brc",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}