Skip to content

Instantly share code, notes, and snippets.

@smsharma
Created June 13, 2023 19:24
Show Gist options
  • Save smsharma/3a3bbfb4fbea3a0880cb7cc0e4f0e706 to your computer and use it in GitHub Desktop.
Save smsharma/3a3bbfb4fbea3a0880cb7cc0e4f0e706 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Uncomment to install dependencies into the active kernel (pin versions for reproducibility)\n",
"# %pip install powerbox\n",
"# %pip install git+https://github.com/tlmakinen/powerbox-jax.git\n",
"# %pip install jax-cosmo\n",
"# %pip install nflows\n",
"# %pip install pytorch-lightning tqdm matplotlib"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as np\n",
"import jax\n",
"import numpy as onp\n",
"import jax_cosmo as jc\n",
"import powerbox_jax as pbj\n",
"import powerbox as pbox\n",
"\n",
"import matplotlib.pyplot as plt"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simulate training power spectra from a log-normal field model"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Fiducial cosmology; these values are also the defaults of P below\n",
"Omega_c_fid, sigma8_fid = 0.40, 0.60\n",
"cosmo_params = jc.Planck15(Omega_c=Omega_c_fid, sigma8=sigma8_fid)\n",
"\n",
"def P(k, Omega_c=Omega_c_fid, sigma8=sigma8_fid):\n",
"    \"\"\"Linear matter power spectrum P(k) from jax_cosmo for a Planck15 cosmology.\n",
"\n",
"    Only Omega_c and sigma8 are overridden; every other parameter keeps its Planck15 value.\n",
"    \"\"\"\n",
"    cosmo = jc.Planck15(Omega_c=Omega_c, sigma8=sigma8)\n",
"    return jc.power.linear_matter_power(cosmo, k)\n",
"\n",
"N = 50  # Grid points per side of the box\n",
"L = 50.  # Box side length; noted as Mpc, but jax_cosmo's k is in h/Mpc, so presumably Mpc/h -- TODO confirm\n",
"\n",
"rng = jax.random.PRNGKey(32)  # Base PRNG key for the field simulations"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# Draw n_train parameter pairs from a uniform prior:\n",
"# Omega_c ~ U(0.1, 0.5), sigma8 ~ U(0.6, 1.0)\n",
"\n",
"from tqdm.auto import tqdm\n",
"from scipy.stats import uniform\n",
"\n",
"n_train = 100\n",
"prior_seed = 0  # fixed seed so the training set is reproducible\n",
"thetas = np.array(uniform.rvs(size=(n_train, 2), loc=(0.1, 0.6), scale=(0.4, 0.4), random_state=prior_seed))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f5cf02dbf8254bfaa363b4d6c073e93a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/100 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def lnpb_fn(theta, key=rng):\n",
"    \"\"\"Simulate a 2D log-normal field at theta = (Omega_c, sigma8) and return its binned power spectrum.\n",
"\n",
"    `key` is the JAX PRNG key used for the field realization. Pass a distinct key per call\n",
"    to get independent realizations; it defaults to the global `rng`.\n",
"    \"\"\"\n",
"    lnpb = pbj.LogNormalPowerBox(\n",
"        N=N, dim=2, pk=lambda k: P(k, Omega_c=theta[0], sigma8=theta[1]) / L, boxlength=L, key=key, vol_normalised_power=True, supplied_freqs=None\n",
"    )\n",
"\n",
"    p_k_lnfield, _ = pbox.get_power(lnpb.delta_x(), lnpb.boxlength)\n",
"    return p_k_lnfield\n",
"\n",
"\n",
"# One independent key per simulation: reusing a single JAX key yields the same random draws\n",
"# for every call, so all training spectra would share one noise realization.\n",
"sim_keys = jax.random.split(rng, len(thetas))\n",
"pspecs = [lnpb_fn(theta, key) for theta, key in zip(tqdm(thetas), sim_keys)]"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Helper function for constructing normalizing flows with `nflows`"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"from nflows import distributions as distributions_\n",
"from nflows import flows, transforms\n",
"from nflows.nn import nets\n",
"import torch\n",
"\n",
"def build_maf(dim=1, num_transforms=8, context_features=None, hidden_features=128, num_blocks=2, activation=torch.tanh):\n",
"    \"\"\"Build a (conditional) masked autoregressive flow with `nflows`.\n",
"\n",
"    Args:\n",
"        dim: number of features modelled by the flow.\n",
"        num_transforms: number of [MADE affine transform, random permutation] layers.\n",
"        context_features: size of the conditioning vector, or None for an unconditional flow.\n",
"        hidden_features: width of the hidden layers of each MADE network.\n",
"        num_blocks: number of hidden blocks in each MADE network.\n",
"        activation: nonlinearity used inside each MADE network.\n",
"\n",
"    Returns:\n",
"        An `nflows.flows.Flow` with a standard-normal base distribution of shape (dim,).\n",
"    \"\"\"\n",
"    layers = [\n",
"        transforms.CompositeTransform(\n",
"            [\n",
"                transforms.MaskedAffineAutoregressiveTransform(\n",
"                    features=dim,\n",
"                    hidden_features=hidden_features,\n",
"                    context_features=context_features,\n",
"                    num_blocks=num_blocks,\n",
"                    use_residual_blocks=False,\n",
"                    random_mask=False,\n",
"                    activation=activation,\n",
"                    dropout_probability=0.0,\n",
"                    use_batch_norm=False,\n",
"                ),\n",
"                # Shuffle feature order so successive layers use different autoregressive orderings\n",
"                transforms.RandomPermutation(features=dim),\n",
"            ]\n",
"        )\n",
"        for _ in range(num_transforms)\n",
"    ]\n",
"    transform = transforms.CompositeTransform(layers)\n",
"    distribution = distributions_.StandardNormal((dim,))\n",
"    return flows.Flow(transform, distribution)\n",
"\n",
"# Smoke test: a conditional flow returns one log-probability per input row\n",
"flow = build_maf(dim=2, context_features=2)\n",
"assert flow.log_prob(torch.zeros(3, 2), context=torch.zeros(3, 2)).shape == (3,)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train normalizing flow"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"from torch.utils.data import TensorDataset, DataLoader"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/nx/bx2847k56j3dddp761x637pc0000gn/T/ipykernel_6719/229339229.py:1: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:248.)\n",
" pspecs = torch.Tensor(pspecs)\n"
]
}
],
"source": [
"# Stack into a single array first: building a tensor from a list of ndarrays is very slow\n",
"pspecs = torch.tensor(onp.array(pspecs), dtype=torch.float32)\n",
"thetas = torch.tensor(onp.array(thetas), dtype=torch.float32)\n",
"assert pspecs.shape[0] == thetas.shape[0], 'one parameter pair per spectrum'"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 4\n",
"\n",
"dataset = TensorDataset(pspecs, thetas)\n",
"# The dataset is a small in-memory tensor pair. Worker processes (re-created every epoch by default)\n",
"# would only add start-up overhead, and pinned memory only benefits host-to-GPU copies.\n",
"train_loader = DataLoader(dataset, batch_size=batch_size, num_workers=0, pin_memory=False, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"class FlowSurrogate(pl.LightningModule):\n",
"    \"\"\"Lightning wrapper that trains a conditional MAF to model p(power spectrum | theta).\n",
"\n",
"    Args:\n",
"        optimizer: optimizer class, instantiated in `configure_optimizers`.\n",
"        optimizer_kwargs: extra keyword arguments for the optimizer (default: weight_decay=1e-5).\n",
"        max_epochs: cosine-annealing period (`T_max`); should match the Trainer's `max_epochs`.\n",
"        scheduler: learning-rate scheduler class, stepped once per epoch.\n",
"        lr: initial learning rate.\n",
"        dim: dimension of the modelled data; defaults to the row length of the global `pspecs`.\n",
"        context_features: dimension of the conditioning parameters theta.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        optimizer=torch.optim.AdamW,\n",
"        optimizer_kwargs=None,\n",
"        max_epochs=50,\n",
"        scheduler=torch.optim.lr_scheduler.CosineAnnealingLR,\n",
"        lr=3e-4,\n",
"        dim=None,\n",
"        context_features=2,\n",
"    ):\n",
"        super().__init__()\n",
"\n",
"        self.optimizer = optimizer\n",
"        # Fresh dict per instance instead of a shared mutable default argument\n",
"        self.optimizer_kwargs = {\"weight_decay\": 1e-5} if optimizer_kwargs is None else optimizer_kwargs\n",
"        self.scheduler = scheduler\n",
"        self.scheduler_kwargs = {\"T_max\": max_epochs}\n",
"        self.lr = lr\n",
"\n",
"        if dim is None:\n",
"            dim = pspecs.shape[-1]\n",
"        self.flow = build_maf(dim=dim, context_features=context_features)\n",
"\n",
"    def forward(self, x, theta):\n",
"        \"\"\"Return the batch-mean log-probability of `x` conditioned on `theta`.\"\"\"\n",
"        log_prob = self.flow.log_prob(x, context=theta)\n",
"        return log_prob.mean()\n",
"\n",
"    def configure_optimizers(self):\n",
"        optimizer = self.optimizer(self.parameters(), lr=self.lr, **self.optimizer_kwargs)\n",
"        scheduler = self.scheduler(optimizer, **self.scheduler_kwargs)\n",
"        # No \"monitor\" key: nothing logs \"val_loss\" (there is no validation step), and only\n",
"        # ReduceLROnPlateau-style schedulers need a monitored metric.\n",
"        return {\"optimizer\": optimizer, \"lr_scheduler\": {\"scheduler\": scheduler, \"interval\": \"epoch\", \"frequency\": 1}}\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        \"\"\"Return the negative mean log-likelihood of a (spectra, theta) batch.\"\"\"\n",
"        x, theta = batch\n",
"        loss = -self(x, theta)\n",
"        self.log(\"train_loss\", loss, on_epoch=True)\n",
"        return loss"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (mps), used: False\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"/opt/homebrew/Caskroom/miniforge/base/envs/torch-mps/lib/python3.9/site-packages/pytorch_lightning/trainer/setup.py:200: UserWarning: MPS available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='mps', devices=1)`.\n",
" rank_zero_warn(\n",
"\n",
" | Name | Type | Params\n",
"------------------------------\n",
"0 | flow | Flow | 336 K \n",
"------------------------------\n",
"336 K Trainable params\n",
"0 Non-trainable params\n",
"336 K Total params\n",
"1.345 Total estimated model params size (MB)\n",
"/opt/homebrew/Caskroom/miniforge/base/envs/torch-mps/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1595: PossibleUserWarning: The number of training batches (25) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
" rank_zero_warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "24d3038cb22d44bdba0f20113067e21a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=50` reached.\n"
]
}
],
"source": [
"pl.seed_everything(42)  # seed torch/numpy/random so flow initialisation and batch shuffling are reproducible\n",
"\n",
"max_epochs = 50  # shared by the LR schedule and the Trainer so the cosine period matches the run length\n",
"model = FlowSurrogate(max_epochs=max_epochs)\n",
"\n",
"trainer = pl.Trainer(\n",
"    max_epochs=max_epochs,\n",
"    accelerator='cpu',\n",
"    devices=1,\n",
"    log_every_n_steps=min(50, len(train_loader)),  # the default (50) can exceed the number of batches per epoch\n",
")\n",
"trainer.fit(model=model, train_dataloaders=train_loader)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# Put the trained module into inference mode (same as model.eval()) before sampling\n",
"model.train(False);"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simulator function and sampling examples"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"def simulator(theta, num_samples=1):\n",
"    \"\"\"Simulator that samples power spectra from the trained normalizing flow.\n",
"\n",
"    Args:\n",
"        theta: parameters (Omega_c, sigma8), either a single pair of shape (2,) or a batch of shape (B, 2).\n",
"        num_samples: number of spectra to draw per parameter pair.\n",
"\n",
"    Returns:\n",
"        Tensor of shape (B, num_samples, dim), with B = 1 for a single pair.\n",
"    \"\"\"\n",
"    theta = torch.as_tensor(theta, dtype=torch.float32)\n",
"    if theta.dim() == 1:\n",
"        # nflows expects a batch of contexts; a bare (2,) vector would be read as two contexts\n",
"        theta = theta.unsqueeze(0)\n",
"    with torch.no_grad():  # sampling only, no gradients needed\n",
"        return model.flow.sample(num_samples=num_samples, context=theta)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'Pk')"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Compare one flow sample against the training simulation for a few parameter pairs\n",
"fig, ax = plt.subplots()\n",
"for i, idx in enumerate([0, 5]):\n",
"    color = f'C{i}'\n",
"    pspec_sample = simulator(thetas[idx], num_samples=1)\n",
"    ax.plot(pspec_sample[0, 0].detach().numpy(), label=f'flow sample (theta #{idx})', color=color)\n",
"    ax.plot(pspecs[idx].detach().numpy(), label=f'simulation (theta #{idx})', color=color, ls='--')\n",
"\n",
"ax.set_yscale('log')\n",
"ax.set_xlabel('k bin index')\n",
"ax.set_ylabel('Pk')\n",
"ax.legend();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "torch-mps",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment