Skip to content

Instantly share code, notes, and snippets.

@kazewong
Created October 6, 2022 03:05
Show Gist options
  • Save kazewong/a912028e4f9e6cdcb8c51765f65241ca to your computer and use it in GitHub Desktop.
Save kazewong/a912028e4f9e6cdcb8c51765f65241ca to your computer and use it in GitHub Desktop.
Notebook for running GW150914 with FlowMC
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "0LM6rDlQI1J2"
},
"source": [
"1. Make sure you set your run time device to GPU\n",
"2. Download the data file from https://drive.google.com/file/d/1Jalw1mJ4_Cvkp_QjpvKIJS1HYA19SgBL/view?usp=sharing\n",
"3. Upload the data to this local environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "g_Lqr_OpQ_OQ"
},
"outputs": [],
"source": [
"!pip install flowMC jaxGW lalsuite\n",
"!pip install --upgrade \"jax[cuda]\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"!git clone https://github.com/maxisi/ripple.git\n",
"%cd ripple\n",
"!pip install ."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "0sX3Ja6wRJfd"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax.numpy as jnp\n",
"import jax\n",
"from lal import GreenwichMeanSiderealTime\n",
"\n",
"from ripple.waveforms.IMRPhenomD import gen_IMRPhenomD_polar\n",
"from jaxgw.PE.detector_preset import * \n",
"from jaxgw.PE.heterodyneLikelihood import make_heterodyne_likelihood\n",
"from jaxgw.PE.detector_projection import make_detector_response\n",
"\n",
"from flowMC.nfmodel.rqSpline import RQSpline\n",
"from flowMC.sampler.MALA import make_mala_sampler, mala_sampler_autotune\n",
"from flowMC.sampler.Sampler import Sampler\n",
"from flowMC.utils.PRNG_keys import initialize_rng_keys\n",
"from flowMC.nfmodel.utils import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "Z5pm-3DcaaNx"
},
"outputs": [],
"source": [
"\n",
"data = np.load('/content/GW150914_data.npz',allow_pickle=True)\n",
"\n",
"minimum_frequency = data['minimum_frequency']\n",
"\n",
"H1_frequency = data['frequency'].tolist()['H1']\n",
"H1_data = data['data'].tolist()['H1'][H1_frequency>minimum_frequency]\n",
"H1_psd = data['psd'].tolist()['H1'][H1_frequency>minimum_frequency]\n",
"H1_frequency = H1_frequency[H1_frequency>minimum_frequency]\n",
"\n",
"L1_frequency = data['frequency'].tolist()['L1']\n",
"L1_data = data['data'].tolist()['L1'][L1_frequency>minimum_frequency]\n",
"L1_psd = data['psd'].tolist()['L1'][L1_frequency>minimum_frequency]\n",
"L1_frequency = L1_frequency[L1_frequency>minimum_frequency]\n",
"\n",
"H1 = get_H1()\n",
"H1_response = make_detector_response(H1[0], H1[1])\n",
"L1 = get_L1()\n",
"L1_response = make_detector_response(L1[0], L1[1])\n",
"\n",
"trigger_time = 1126259462.4\n",
"duration = 4 \n",
"post_trigger_duration = 2\n",
"epoch = duration - post_trigger_duration\n",
"gmst = GreenwichMeanSiderealTime(trigger_time)\n",
"f_ref = 20\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "G6cm0OTiJIcQ"
},
"outputs": [],
"source": [
"\n",
"ref_param = jnp.array([ 3.14562104e+01, 2.49863038e-01, 3.01633871e-02, 1.44865987e-02,\n",
" 5.41706887e+02, -1.42249470e-02, 1.90571654e+00, 3.02521496e+00,\n",
" 1.13895265e+00, 1.80359165e+00, -1.28078777e+00])\n",
"\n",
"\n",
"from jaxgw.PE.heterodyneLikelihood import make_heterodyne_likelihood_mutliple_detector\n",
"\n",
"data_list = [H1_data, L1_data]\n",
"psd_list = [H1_psd, L1_psd]\n",
"response_list = [H1_response, L1_response]\n",
"\n",
"logL = make_heterodyne_likelihood_mutliple_detector(data_list, psd_list, response_list, gen_IMRPhenomD_polar, ref_param, H1_frequency, gmst, epoch, f_ref, 101)\n",
"\n",
"\n",
"n_dim = 11\n",
"n_chains = 300\n",
"n_loop_training = 50\n",
"n_loop_production = 20\n",
"n_local_steps = 200\n",
"n_global_steps = 200\n",
"learning_rate = 0.001\n",
"max_samples = 50000\n",
"momentum = 0.9\n",
"num_epochs = 60\n",
"batch_size = 50000\n",
"\n",
"guess_param = ref_param\n",
"\n",
"guess_param = np.array(jnp.repeat(guess_param[None,:],int(n_chains),axis=0)*np.random.normal(loc=1,scale=0.1,size=(int(n_chains),n_dim)))\n",
"guess_param[guess_param[:,1]>0.25,1] = 0.249\n",
"guess_param[:,6] = (guess_param[:,6]%(2*jnp.pi))\n",
"guess_param[:,7] = (guess_param[:,7]%(jnp.pi))\n",
"guess_param[:,8] = (guess_param[:,8]%(jnp.pi))\n",
"guess_param[:,9] = (guess_param[:,9]%(2*jnp.pi))\n",
"\n",
"\n",
"print(\"Preparing RNG keys\")\n",
"rng_key_set = initialize_rng_keys(n_chains, seed=42)\n",
"\n",
"print(\"Initializing MCMC model and normalizing flow model.\")\n",
"\n",
"prior_range = jnp.array([[10,80],[0.10,0.25],[0,1],[0,1],[0,2000],[-0.1,0.1],[0,2*np.pi],[0,np.pi],[0,np.pi],[0,2*np.pi],[-jnp.pi/2,jnp.pi/2]])\n",
"\n",
"initial_position = jax.random.uniform(rng_key_set[0], shape=(int(n_chains), n_dim)) * 1\n",
"for i in range(n_dim):\n",
" initial_position = initial_position.at[:,i].set(initial_position[:,i]*(prior_range[i,1]-prior_range[i,0])+prior_range[i,0])\n",
"\n",
"# initial_position = initial_position.at[:,0].set(guess_param[:,0])\n",
"# initial_position = initial_position.at[:,1].set(guess_param[:,1])\n",
"\n",
"def top_hat(x):\n",
" output = 0.\n",
" for i in range(n_dim):\n",
" output = jax.lax.cond(x[i]>=prior_range[i,0], lambda: output, lambda: -jnp.inf)\n",
" output = jax.lax.cond(x[i]<=prior_range[i,1], lambda: output, lambda: -jnp.inf)\n",
" return output\n",
"\n",
"def posterior(theta):\n",
" prior = top_hat(theta)\n",
" return logL(theta) + prior\n",
"\n",
"model = RQSpline(n_dim, 6, [64,64], 8)\n",
"\n",
"print(\"Initializing sampler class\")\n",
"\n",
"posterior = posterior\n",
"\n",
"mass_matrix = jnp.eye(n_dim)\n",
"mass_matrix = mass_matrix.at[1,1].set(1e-3)\n",
"mass_matrix = mass_matrix.at[5,5].set(1e-3)\n",
"\n",
"local_sampler_caller = lambda x: make_mala_sampler(x, jit=True)\n",
"sampler_params = {'dt':mass_matrix*3e-3}\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "uw06jYnFJo0I"
},
"outputs": [],
"source": [
"print(\"Running sampler\")\n",
"\n",
"nf_sampler = Sampler(\n",
" n_dim,\n",
" rng_key_set,\n",
" local_sampler_caller,\n",
" sampler_params,\n",
" posterior,\n",
" model,\n",
" n_loop_training=n_loop_training,\n",
" n_loop_production = n_loop_production,\n",
" n_local_steps=n_local_steps,\n",
" n_global_steps=n_global_steps,\n",
" n_chains=n_chains,\n",
" n_epochs=num_epochs,\n",
" learning_rate=learning_rate,\n",
" momentum=momentum,\n",
" batch_size=batch_size,\n",
" use_global=True,\n",
" keep_quantile=0.,\n",
")\n",
"\n",
"nf_sampler.sample(initial_position)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "xhSYIpz9KAMM"
},
"outputs": [],
"source": [
"prod = nf_sampler.get_sampler_state(training=True)\n",
"chains, log_prob, local_accs, global_accs, loss = prod.values()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "Gpnjb0IfNbJC"
},
"outputs": [],
"source": [
"jnp.mean(global_accs[:,-1000:])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "O16AMqZkNdY_"
},
"outputs": [],
"source": [
"!pip install corner"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "-HopzhP-Nf_6"
},
"outputs": [],
"source": [
"import corner\n",
"corner.corner(np.array(chains[:,-2000:]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "ARKRMHOKN8tU"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment