Created
October 6, 2022 03:05
-
-
Save kazewong/a912028e4f9e6cdcb8c51765f65241ca to your computer and use it in GitHub Desktop.
Notebook for running GW150914 with FlowMC
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "0LM6rDlQI1J2" | |
}, | |
"source": [ | |
"1. Make sure you set your run time device to GPU\n", | |
"2. Download the data file from https://drive.google.com/file/d/1Jalw1mJ4_Cvkp_QjpvKIJS1HYA19SgBL/view?usp=sharing\n", | |
"3. Upload the data to this local environment" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"background_save": true | |
}, | |
"id": "g_Lqr_OpQ_OQ" | |
}, | |
"outputs": [], | |
"source": [ | |
"!pip install flowMC jaxGW lalsuite\n", | |
"!pip install --upgrade \"jax[cuda]\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n", | |
"!git clone https://github.com/maxisi/ripple.git\n", | |
"%cd ripple\n", | |
"!pip install ." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"background_save": true | |
}, | |
"id": "0sX3Ja6wRJfd" | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import jax.numpy as jnp\n", | |
"import jax\n", | |
"from lal import GreenwichMeanSiderealTime\n", | |
"\n", | |
"from ripple.waveforms.IMRPhenomD import gen_IMRPhenomD_polar\n", | |
"from jaxgw.PE.detector_preset import * \n", | |
"from jaxgw.PE.heterodyneLikelihood import make_heterodyne_likelihood\n", | |
"from jaxgw.PE.detector_projection import make_detector_response\n", | |
"\n", | |
"from flowMC.nfmodel.rqSpline import RQSpline\n", | |
"from flowMC.sampler.MALA import make_mala_sampler, mala_sampler_autotune\n", | |
"from flowMC.sampler.Sampler import Sampler\n", | |
"from flowMC.utils.PRNG_keys import initialize_rng_keys\n", | |
"from flowMC.nfmodel.utils import *" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"background_save": true | |
}, | |
"id": "Z5pm-3DcaaNx" | |
}, | |
"outputs": [], | |
"source": [ | |
"\n", | |
"data = np.load('/content/GW150914_data.npz',allow_pickle=True)\n", | |
"\n", | |
"minimum_frequency = data['minimum_frequency']\n", | |
"\n", | |
"H1_frequency = data['frequency'].tolist()['H1']\n", | |
"H1_data = data['data'].tolist()['H1'][H1_frequency>minimum_frequency]\n", | |
"H1_psd = data['psd'].tolist()['H1'][H1_frequency>minimum_frequency]\n", | |
"H1_frequency = H1_frequency[H1_frequency>minimum_frequency]\n", | |
"\n", | |
"L1_frequency = data['frequency'].tolist()['L1']\n", | |
"L1_data = data['data'].tolist()['L1'][L1_frequency>minimum_frequency]\n", | |
"L1_psd = data['psd'].tolist()['L1'][L1_frequency>minimum_frequency]\n", | |
"L1_frequency = L1_frequency[L1_frequency>minimum_frequency]\n", | |
"\n", | |
"H1 = get_H1()\n", | |
"H1_response = make_detector_response(H1[0], H1[1])\n", | |
"L1 = get_L1()\n", | |
"L1_response = make_detector_response(L1[0], L1[1])\n", | |
"\n", | |
"trigger_time = 1126259462.4\n", | |
"duration = 4 \n", | |
"post_trigger_duration = 2\n", | |
"epoch = duration - post_trigger_duration\n", | |
"gmst = GreenwichMeanSiderealTime(trigger_time)\n", | |
"f_ref = 20\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"background_save": true | |
}, | |
"id": "G6cm0OTiJIcQ" | |
}, | |
"outputs": [], | |
"source": [ | |
"\n", | |
"ref_param = jnp.array([ 3.14562104e+01, 2.49863038e-01, 3.01633871e-02, 1.44865987e-02,\n", | |
" 5.41706887e+02, -1.42249470e-02, 1.90571654e+00, 3.02521496e+00,\n", | |
" 1.13895265e+00, 1.80359165e+00, -1.28078777e+00])\n", | |
"\n", | |
"\n", | |
"from jaxgw.PE.heterodyneLikelihood import make_heterodyne_likelihood_mutliple_detector\n", | |
"\n", | |
"data_list = [H1_data, L1_data]\n", | |
"psd_list = [H1_psd, L1_psd]\n", | |
"response_list = [H1_response, L1_response]\n", | |
"\n", | |
"logL = make_heterodyne_likelihood_mutliple_detector(data_list, psd_list, response_list, gen_IMRPhenomD_polar, ref_param, H1_frequency, gmst, epoch, f_ref, 101)\n", | |
"\n", | |
"\n", | |
"n_dim = 11\n", | |
"n_chains = 300\n", | |
"n_loop_training = 50\n", | |
"n_loop_production = 20\n", | |
"n_local_steps = 200\n", | |
"n_global_steps = 200\n", | |
"learning_rate = 0.001\n", | |
"max_samples = 50000\n", | |
"momentum = 0.9\n", | |
"num_epochs = 60\n", | |
"batch_size = 50000\n", | |
"\n", | |
"guess_param = ref_param\n", | |
"\n", | |
"guess_param = np.array(jnp.repeat(guess_param[None,:],int(n_chains),axis=0)*np.random.normal(loc=1,scale=0.1,size=(int(n_chains),n_dim)))\n", | |
"guess_param[guess_param[:,1]>0.25,1] = 0.249\n", | |
"guess_param[:,6] = (guess_param[:,6]%(2*jnp.pi))\n", | |
"guess_param[:,7] = (guess_param[:,7]%(jnp.pi))\n", | |
"guess_param[:,8] = (guess_param[:,8]%(jnp.pi))\n", | |
"guess_param[:,9] = (guess_param[:,9]%(2*jnp.pi))\n", | |
"\n", | |
"\n", | |
"print(\"Preparing RNG keys\")\n", | |
"rng_key_set = initialize_rng_keys(n_chains, seed=42)\n", | |
"\n", | |
"print(\"Initializing MCMC model and normalizing flow model.\")\n", | |
"\n", | |
"prior_range = jnp.array([[10,80],[0.10,0.25],[0,1],[0,1],[0,2000],[-0.1,0.1],[0,2*np.pi],[0,np.pi],[0,np.pi],[0,2*np.pi],[-jnp.pi/2,jnp.pi/2]])\n", | |
"\n", | |
"initial_position = jax.random.uniform(rng_key_set[0], shape=(int(n_chains), n_dim)) * 1\n", | |
"for i in range(n_dim):\n", | |
" initial_position = initial_position.at[:,i].set(initial_position[:,i]*(prior_range[i,1]-prior_range[i,0])+prior_range[i,0])\n", | |
"\n", | |
"# initial_position = initial_position.at[:,0].set(guess_param[:,0])\n", | |
"# initial_position = initial_position.at[:,1].set(guess_param[:,1])\n", | |
"\n", | |
"def top_hat(x):\n", | |
" output = 0.\n", | |
" for i in range(n_dim):\n", | |
" output = jax.lax.cond(x[i]>=prior_range[i,0], lambda: output, lambda: -jnp.inf)\n", | |
" output = jax.lax.cond(x[i]<=prior_range[i,1], lambda: output, lambda: -jnp.inf)\n", | |
" return output\n", | |
"\n", | |
"def posterior(theta):\n", | |
" prior = top_hat(theta)\n", | |
" return logL(theta) + prior\n", | |
"\n", | |
"model = RQSpline(n_dim, 6, [64,64], 8)\n", | |
"\n", | |
"print(\"Initializing sampler class\")\n", | |
"\n", | |
"posterior = posterior\n", | |
"\n", | |
"mass_matrix = jnp.eye(n_dim)\n", | |
"mass_matrix = mass_matrix.at[1,1].set(1e-3)\n", | |
"mass_matrix = mass_matrix.at[5,5].set(1e-3)\n", | |
"\n", | |
"local_sampler_caller = lambda x: make_mala_sampler(x, jit=True)\n", | |
"sampler_params = {'dt':mass_matrix*3e-3}\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"background_save": true | |
}, | |
"id": "uw06jYnFJo0I" | |
}, | |
"outputs": [], | |
"source": [ | |
"print(\"Running sampler\")\n", | |
"\n", | |
"nf_sampler = Sampler(\n", | |
" n_dim,\n", | |
" rng_key_set,\n", | |
" local_sampler_caller,\n", | |
" sampler_params,\n", | |
" posterior,\n", | |
" model,\n", | |
" n_loop_training=n_loop_training,\n", | |
" n_loop_production = n_loop_production,\n", | |
" n_local_steps=n_local_steps,\n", | |
" n_global_steps=n_global_steps,\n", | |
" n_chains=n_chains,\n", | |
" n_epochs=num_epochs,\n", | |
" learning_rate=learning_rate,\n", | |
" momentum=momentum,\n", | |
" batch_size=batch_size,\n", | |
" use_global=True,\n", | |
" keep_quantile=0.,\n", | |
")\n", | |
"\n", | |
"nf_sampler.sample(initial_position)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"background_save": true | |
}, | |
"id": "xhSYIpz9KAMM" | |
}, | |
"outputs": [], | |
"source": [ | |
"prod = nf_sampler.get_sampler_state(training=True)\n", | |
"chains, log_prob, local_accs, global_accs, loss = prod.values()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"background_save": true | |
}, | |
"id": "Gpnjb0IfNbJC" | |
}, | |
"outputs": [], | |
"source": [ | |
"jnp.mean(global_accs[:,-1000:])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"background_save": true | |
}, | |
"id": "O16AMqZkNdY_" | |
}, | |
"outputs": [], | |
"source": [ | |
"!pip install corner" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"background_save": true | |
}, | |
"id": "-HopzhP-Nf_6" | |
}, | |
"outputs": [], | |
"source": [ | |
"import corner\n", | |
"corner.corner(np.array(chains[:,-2000:]))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"background_save": true | |
}, | |
"id": "ARKRMHOKN8tU" | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"accelerator": "GPU", | |
"colab": { | |
"collapsed_sections": [], | |
"provenance": [] | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment