-
-
Save alexander-held/b020445d089c3ef79abe93d63543c5e7 to your computer and use it in GitHub Desktop.
analytic staterror optimization
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "bd7c82f4", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import math\n", | |
"\n", | |
"import cabinetry\n", | |
"import iminuit\n", | |
"import numpy as np\n", | |
"import pyhf\n", | |
"\n", | |
"cabinetry.set_logging()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "2442e330", | |
"metadata": {}, | |
"source": [ | |
"### define a simple 2-bin model for testing" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "72a9dc40", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO - pyhf.workspace - Validating spec against schema: workspace.json\n", | |
"INFO - pyhf.pdf - Validating spec against schema: model.json\n", | |
"INFO - pyhf.pdf - adding modifier Signal_norm (1 new nuisance parameters)\n", | |
"INFO - pyhf.pdf - adding modifier staterror_SR (2 new nuisance parameters)\n" | |
] | |
} | |
], | |
"source": [ | |
"spec = {\n", | |
" \"channels\": [\n", | |
" {\n", | |
" \"name\": \"SR\",\n", | |
" \"samples\": [\n", | |
" {\n", | |
" \"data\": [25.0, 40.0],\n", | |
" \"modifiers\": [\n", | |
" {\"data\": [2.0, 4.0], \"name\": \"staterror_SR\", \"type\": \"staterror\"},\n", | |
" {\"data\": None, \"name\": \"Signal_norm\", \"type\": \"normfactor\"}\n", | |
" ],\n", | |
" \"name\": \"Signal\"\n", | |
" }\n", | |
" ]\n", | |
" }\n", | |
" ],\n", | |
" \"measurements\": [\n", | |
" {\n", | |
" \"config\": {\n", | |
" \"parameters\": [{\"name\": \"staterror_SR\", \"auxdata\": [1.2, 1.0]}],\n", | |
" \"poi\": \"Signal_norm\"\n", | |
" },\n", | |
" \"name\": \"minimal_example\"\n", | |
" }\n", | |
" ],\n", | |
" \"observations\": [\n", | |
" {\n", | |
" \"data\": [28.0, 37.0],\n", | |
" \"name\": \"SR\"\n", | |
" }\n", | |
" ],\n", | |
" \"version\": \"1.0.0\"\n", | |
"}\n", | |
"\n", | |
"model, data = cabinetry.model_utils.model_and_data(spec)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "db4c73cd", | |
"metadata": {}, | |
"source": [ | |
"### derivation of analytical optimization\n", | |
"\n", | |
"**Poisson term:** $$P(x|r) = \\frac{e^{-r}r^x}{x!}$$\n", | |
"\n", | |
"- $r \\propto \\gamma$ per bin -> let $r = \\gamma m$ with $m$ being the model prediction without gammas\n", | |
"- $x$ is constant (the observed count in the bin)\n", | |
"\n", | |
"derivative: $$\\frac{d}{d \\gamma} \\ln P(x|\\gamma m) = \\frac{x}{\\gamma} - m$$\n", | |
"\n", | |
"\n", | |
"**Gaussian term for $\\gamma$ (with aux $a$):** $$G(a | \\gamma, \\sigma) = \\frac{1}{\\sqrt{2\\pi \\sigma^2}} e^{-\\frac{1}{2}\\frac{(a-\\gamma)^2}{\\sigma^2}}$$\n", | |
"\n", | |
"where\n", | |
"- $\\gamma$ behaves like the parameter in the fit (centered around 1)\n", | |
"- $\\sigma$ is the uncertainty per bin: `model.config.param_set(\"staterror_SR\").width()`\n", | |
"\n", | |
"Gaussian derivative for gamma: $$\\frac{d}{d\\gamma} \\ln G(a | \\gamma, \\sigma) = \\frac{a-\\gamma}{\\sigma^2}$$\n", | |
"\n", | |
"-> makes sense: the derivative is zero for $\\gamma = a$, i.e. no pull away from the auxiliary measurement (for the nominal $a = 1$ this is $\\gamma = 1$)\n", | |
"\n", | |
"\n", | |
"**Putting everything together:** $$ \\ln L(x, a | \\gamma) = \\ln P(x|r) + \\ln G(a | \\gamma, \\sigma)$$\n", | |
"which will be maximized at $$\\frac{d}{d\\gamma} \\ln L(x, a | \\gamma) = 0 $$\n", | |
"implying that\n", | |
"\n", | |
"$$\\frac{a-\\gamma}{\\sigma^2} + \\frac{x}{\\gamma} - m = 0$$\n", | |
"\n", | |
"which can also be rewritten as $$\\gamma^2 + (m\\sigma^2 -a)\\gamma - x\\sigma^2 = 0$$\n", | |
"reproducing eq 11 in https://arxiv.org/abs/1103.0354\n", | |
"\n", | |
"-> solve quadratic equation for $\\gamma$\n", | |
"\n", | |
"$$\\gamma = \\frac{1}{2}(\\pm \\sqrt{a^2-2am\\sigma^2+m^2\\sigma^4+4\\sigma^2x}+a-m\\sigma^2)$$" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "c97845c7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"analytical solution: [1.1905221845696083, 0.9782329983125267]\n" | |
] | |
} | |
], | |
"source": [ | |
"def get_staterror_aux(model):\n", | |
" \"\"\"returns the auxdata values for staterror terms\"\"\"\n", | |
" # TODO: can it ever happen that this is NOT correctly ordered?\n", | |
" # how could this be done in a way that doesn't implicitly rely\n", | |
" # on the parameter order matching the bin order?\n", | |
" stat_aux = []\n", | |
" for par in model.config.par_order:\n", | |
" if dict(model.config.modifiers)[par] == \"staterror\":\n", | |
" stat_aux += model.config.param_set(par).auxdata\n", | |
" return stat_aux\n", | |
"\n", | |
"\n", | |
"def best_gammas(model, data, pars):\n", | |
" \"\"\"analytical solution for best-fit staterror values\"\"\"\n", | |
" best_gammas = []\n", | |
" for i_bin in range(sum(model.config.channel_nbins.values())):\n", | |
" s = model.config.param_set(\"staterror_SR\").width()[i_bin] # sigma\n", | |
" x = data[i_bin] # obs data\n", | |
" m = model.main_model.expected_data(pars)[i_bin] # model pred\n", | |
" a = get_staterror_aux(model)[i_bin] # aux data\n", | |
"\n", | |
" # TODO: this assumes that all samples get scaled by the gamma\n", | |
"\n", | |
" # solve the quadratic equation\n", | |
" sol_1 = 1/2*((a**2-2*a*m*s**2+m**2*s**4+4*s**2*x)**0.5 + a - m*s**2)\n", | |
" sol_2 = 1/2*(-(a**2-2*a*m*s**2+m**2*s**4+4*s**2*x)**0.5 + a - m*s**2)\n", | |
" assert sol_2 < 0 # ensure that the first solution is the only physically sensible one\n", | |
" best_gammas.append(sol_1)\n", | |
" return best_gammas\n", | |
"\n", | |
"\n", | |
"print(f\"analytical solution: {best_gammas(model, data, model.config.suggested_init())}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "72c4cfdc", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO - cabinetry.fit - performing maximum likelihood fit\n", | |
"INFO - cabinetry.fit - Migrad status:\n", | |
"┌─────────────────────────────────────────────────────────────────────────┐\n", | |
"│ Migrad │\n", | |
"├──────────────────────────────────┬──────────────────────────────────────┤\n", | |
"│ FCN = 4.934 │ Nfcn = 44 │\n", | |
"│ EDM = 3.78e-10 (Goal: 2e-08) │ │\n", | |
"├──────────────────────────────────┼──────────────────────────────────────┤\n", | |
"│ Valid Minimum │ No Parameters at limit │\n", | |
"├──────────────────────────────────┼──────────────────────────────────────┤\n", | |
"│ Below EDM threshold (goal x 10) │ Below call limit │\n", | |
"├───────────────┬──────────────────┼───────────┬─────────────┬────────────┤\n", | |
"│ Covariance │ Hesse ok │ Accurate │ Pos. def. │ Not forced │\n", | |
"└───────────────┴──────────────────┴───────────┴─────────────┴────────────┘\n", | |
"DEBUG - cabinetry.fit - -2 log(L) = 4.934471 at best-fit point\n", | |
"INFO - cabinetry.fit - fit results (with symmetric uncertainties):\n", | |
"INFO - cabinetry.fit - Signal_norm = 1.0000 +/- 0.0000 (constant)\n", | |
"INFO - cabinetry.fit - staterror_SR[0] = 1.1905 +/- 0.0754\n", | |
"INFO - cabinetry.fit - staterror_SR[1] = 0.9782 +/- 0.0849\n" | |
] | |
} | |
], | |
"source": [ | |
"fit_results = cabinetry.fit.fit(model, data, fix_pars=[True, False, False]) # fix POI here for a test" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "770d8070", | |
"metadata": {}, | |
"source": [ | |
"### define helpers and objective function" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "8e3ee3d6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# define some helper functions\n", | |
"\n", | |
"def staterror_indices(model):\n", | |
" \"\"\"indices of staterror parameters in a model\"\"\"\n", | |
" indices = []\n", | |
" for par in model.config.par_order:\n", | |
" if dict(model.config.modifiers)[par] == \"staterror\":\n", | |
" idx_slice = model.config.par_map[par][\"slice\"]\n", | |
" indices += [idx for idx in range(idx_slice.start, idx_slice.stop)]\n", | |
"\n", | |
" # TODO: handle staterrors not present in some channels / partially pruned within channels\n", | |
" assert len(indices) == sum(model.config.channel_nbins.values())\n", | |
" return indices\n", | |
"\n", | |
"\n", | |
"def fix_staterror_pars(model, default_fixed):\n", | |
" \"\"\"fix all staterrors\"\"\"\n", | |
" indices_to_fix = staterror_indices(model)\n", | |
" for idx in indices_to_fix:\n", | |
" default_fixed[idx] = True\n", | |
" return default_fixed" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "1b59e13d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"6.17239716466591" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"def build_twice_nll(model, data):\n", | |
" \"\"\"returns objective function that automatically optimizes staterror parameters\"\"\"\n", | |
" def twice_nll_func(pars):\n", | |
" \"\"\"objective function to hand to minimizer\"\"\"\n", | |
" main_data, aux_data = model.fullpdf_tv.split(pyhf.tensorlib.astensor(data))\n", | |
"\n", | |
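" # get best gammas for current configuration; the prediction used inside best_gammas is only free of gamma scaling while the staterror entries of pars are 1.0 (as when they are fixed at their initial values)\n", | |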
" gammas = best_gammas(model, data, pars)\n", | |
" for i, idx in enumerate(staterror_indices(model)):\n", | |
" pars[idx] = gammas[i] # insert optimal gamma values into parameters\n", | |
"\n", | |
" poisson_ll = pyhf.tensorlib.to_numpy(\n", | |
" sum(pyhf.tensorlib.poisson_dist(model.main_model.expected_data(pars)).log_prob(main_data))\n", | |
" )\n", | |
" constraint_ll = pyhf.tensorlib.to_numpy(\n", | |
" model.constraint_logpdf(aux_data, pyhf.tensorlib.astensor(pars))\n", | |
" )\n", | |
" return -2 * (poisson_ll + constraint_ll)\n", | |
" return twice_nll_func\n", | |
"\n", | |
"\n", | |
"# test this with a call\n", | |
"build_twice_nll(model, data)([1.1, 1.0, 1.0])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "9123681a", | |
"metadata": {}, | |
"source": [ | |
"### use this in a fit" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "b166baee", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"# function calls: 23 (including HESSE)\n", | |
"best-fit results (for all parameters other than staterror):\n", | |
" Signal_norm: 0.92898\n", | |
" staterror_SR[0]: 1.00000\n", | |
" staterror_SR[1]: 1.00000\n", | |
"optimal gammas [1.200619191506972, 0.9988377254391094]\n" | |
] | |
} | |
], | |
"source": [ | |
"# standard iminuit fit with custom objective function that does the analytical optimization internally\n", | |
"def fit_with_analytical_staterror(model):\n", | |
" \"\"\"MLE with analytical staterror optimization\"\"\"\n", | |
" init_pars = model.config.suggested_init()\n", | |
" labels = model.config.par_names\n", | |
"\n", | |
" m = iminuit.Minuit(build_twice_nll(model, data), init_pars, name=labels)\n", | |
" m.fixed = fix_staterror_pars(model, model.config.suggested_fixed()) # fix staterror parameters manually\n", | |
" m.errordef = 1\n", | |
" m.print_level = 1\n", | |
"\n", | |
" m.migrad()\n", | |
" m.hesse() # TODO: should probably let the staterror parameters float again here?\n", | |
"\n", | |
" # TODO: could insert the best-fit staterror values here at the end into the fit result probably\n", | |
"\n", | |
" print(f\"# function calls: {m.nfcn} (including HESSE)\")\n", | |
" print(\"best-fit results (for all parameters other than staterror):\")\n", | |
" for val, name in zip(np.asarray(m.values), model.config.par_names):\n", | |
" print(f\" {name}: {val:.5f}\")\n", | |
"\n", | |
" return np.asarray(m.values)\n", | |
"\n", | |
"\n", | |
"best_fit_params = fit_with_analytical_staterror(model)\n", | |
"\n", | |
"# the best-fit results for staterrors need to be evaluated again as they were not propagated\n", | |
"print(\"optimal gammas\", best_gammas(model, data, best_fit_params))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "2671eeda", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO - cabinetry.fit - performing maximum likelihood fit\n", | |
"INFO - cabinetry.fit - Migrad status:\n", | |
"┌─────────────────────────────────────────────────────────────────────────┐\n", | |
"│ Migrad │\n", | |
"├──────────────────────────────────┬──────────────────────────────────────┤\n", | |
"│ FCN = 4.65 │ Nfcn = 74 │\n", | |
"│ EDM = 1.34e-08 (Goal: 2e-08) │ │\n", | |
"├──────────────────────────────────┼──────────────────────────────────────┤\n", | |
"│ Valid Minimum │ No Parameters at limit │\n", | |
"├──────────────────────────────────┼──────────────────────────────────────┤\n", | |
"│ Below EDM threshold (goal x 10) │ Below call limit │\n", | |
"├───────────────┬──────────────────┼───────────┬─────────────┬────────────┤\n", | |
"│ Covariance │ Hesse ok │ Accurate │ Pos. def. │ Not forced │\n", | |
"└───────────────┴──────────────────┴───────────┴─────────────┴────────────┘\n", | |
"DEBUG - cabinetry.fit - -2 log(L) = 4.649501 at best-fit point\n", | |
"INFO - cabinetry.fit - fit results (with symmetric uncertainties):\n", | |
"INFO - cabinetry.fit - Signal_norm = 0.9290 +/- 0.1290\n", | |
"INFO - cabinetry.fit - staterror_SR[0] = 1.2006 +/- 0.0776\n", | |
"INFO - cabinetry.fit - staterror_SR[1] = 0.9988 +/- 0.0933\n" | |
] | |
} | |
], | |
"source": [ | |
"# compare to standard MLE with everything floating\n", | |
"_ = cabinetry.fit.fit(model, data)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "4ecd819c", | |
"metadata": {}, | |
"source": [ | |
"-> the standard fit needs more function calls (Nfcn = 74, compared to 23 with the analytical optimization) because the MC stat parameters are also varied by the minimizer, but the best-fit values of all parameters agree to the precision shown" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "793034cc", | |
"metadata": {}, | |
"source": [ | |
"### some sanity checks used for debugging" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "3263b8f6", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"aux [1.2, 1.0]\n", | |
"width [0.08, 0.1]\n", | |
"manual constraint term: 0.20918667089295906\n", | |
"pyhf: 0.20918667089295961\n" | |
] | |
} | |
], | |
"source": [ | |
"# check for Gaussian term\n", | |
"def gauss(x, mu, sigma):\n", | |
" return 1/(2*math.pi*sigma**2)**0.5 * math.exp(-1/2*(x-mu)**2/(sigma)**2)\n", | |
"\n", | |
"\n", | |
"gammas = [1.1, 1.2] # use (arbitrary) gamma values to check\n", | |
"\n", | |
"aux = model.config.auxdata\n", | |
"print(\"aux\", aux)\n", | |
"width = model.config.param_set(\"staterror_SR\").width()\n", | |
"print(\"width\", width)\n", | |
"\n", | |
"# gauss(x, mu, sigma) arguments: x = auxdata, mu = staterror parameter values (gammas = 1.1, 1.2 in this test), sigma = widths (0.08, 0.1)\n", | |
"constraint_ll = math.log(gauss(aux[0], gammas[0], width[0])) + math.log(gauss(aux[1], gammas[1], width[1]))\n", | |
"print(f\"manual constraint term: {constraint_ll}\")\n", | |
"\n", | |
"pars = [1.0] + gammas\n", | |
"print(f\"pyhf: {model.constraint_logpdf(pyhf.tensorlib.astensor(aux), pyhf.tensorlib.astensor(pars))}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "0922b432", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"manual Poisson term: -2.5928656395163365\n", | |
"pyhf: -2.592865639516333\n" | |
] | |
} | |
], | |
"source": [ | |
"# check for Poisson term\n", | |
"def poiss(x, r):\n", | |
" return math.exp(-r)*(r**x) / math.factorial(x)\n", | |
"\n", | |
"\n", | |
"pred = 25*1.1 # sample * gamma value\n", | |
"obs = 26\n", | |
"\n", | |
"# log Poisson probability makes sense\n", | |
"print(f\"manual Poisson term: {math.log(poiss(obs, pred))}\")\n", | |
"print(f\"pyhf: {pyhf.tensorlib.poisson_dist(pred).log_prob(obs)}\")" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.16" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment