Created
June 27, 2024 14:32
-
-
Save rmnldwg/ea790ac9fa469a6cd51613c94aa005a9 to your computer and use it in GitHub Desktop.
Performance-testing script for debugging intermittent issues with the `lymph` package.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Test different setting's impact on model's performance.""" | |
import argparse
from pathlib import Path
from typing import Any

import emcee
import numpy as np
from multiprocess import Pool
from scipy.special import comb, factorial

import lymph
import lymph.modalities
from lymph.types import ParamsType
# Upper bound of the time-step support; priors are evaluated on 0..MAX_TIME
# (inclusive), i.e. on `np.arange(MAX_TIME + 1)`.
MAX_TIME = 15

# True parameter values used to generate the synthetic patients.
# Insertion order matters: it fixes the column order of the sampler's
# parameter vector (`parameter_names=list(PARAMS.keys())`) and is zipped
# against the best sample with `strict=True`.
PARAMS = dict(
    growth=0.5,
    II_growth=0.7,
    TtoII_spread=0.24,
    TtoIII_spread=0.03,
    TtoIV_spread=0.2,
    IItoIII_spread=0.18,
    IIItoIV_spread=0.18,
    late_p=0.6,
)

# Graph of the model: each (kind, name) node maps to the LNLs it connects to,
# matching the `*_spread` parameter names above (T->II/III/IV, II->III, III->IV).
GRAPH_DICT = {
    ("tumor", "T"): ["II", "III", "IV"],
    ("lnl", "II"): ["III"],
    ("lnl", "III"): ["IV"],
    ("lnl", "IV"): [],
}
def binom_pmf(k: np.ndarray, n: int, p: float) -> np.ndarray:
    """Binomial PMF.

    Evaluate the probability mass function of a binomial distribution with
    ``n`` trials and success probability ``p`` at the value(s) ``k``.

    Args:
        k: Number(s) of successes at which to evaluate the PMF. Works
            element-wise on arrays.
        n: Number of trials.
        p: Success probability; must lie in the closed interval ``[0, 1]``.

    Returns:
        The PMF evaluated element-wise at ``k``.

    Raises:
        ValueError: If ``p`` lies outside ``[0, 1]`` or is NaN.
    """
    # The chained comparison is False for NaN, so NaN is rejected as well
    # (a `p > 1 or p < 0` check would silently let it through).
    if not 0.0 <= p <= 1.0:
        raise ValueError("Binomial prob must be btw. 0 and 1")
    q = 1.0 - p
    # `comb` stays finite for large `n`, whereas float factorials overflow to
    # inf for n > 170 and the factorial ratio then degenerates to NaN.
    binom_coeff = comb(n, k)
    return binom_coeff * p**k * q ** (n - k)
def late_binomial(support: np.ndarray, p: float = 0.5) -> np.ndarray:
    """Binomial time prior whose number of trials is the last support value.

    Args:
        support: Time steps at which to evaluate the prior. Its last entry is
            used as the number of trials ``n``, so it is expected to be sorted
            and to end at the maximum time (e.g. ``np.arange(MAX_TIME + 1)``).
        p: Success probability of the binomial distribution. NOTE(review):
            this is presumably what lymph exposes as the ``late_p`` parameter
            once registered for the "late" T-stage; confirm against lymph.

    Returns:
        The binomial PMF evaluated at every entry of ``support``.
    """
    num_trials = support[-1]
    return binom_pmf(k=support, n=num_trials, p=p)
def setup_trinary_model():
    """Build the trinary unilateral model used throughout this script.

    The model is built from ``GRAPH_DICT`` with ``max_time=MAX_TIME``. It
    registers two modalities:

    - "path": spec 1, sens 1, kind "pathological".
    - "dc": spec 0.94, sens 1, kind "clinical".

    The "early" T-stage gets a fixed binomial prior (p=0.3) over
    ``0..MAX_TIME``. The "late" T-stage gets the parametrized
    ``late_binomial`` prior.

    Returns:
        The configured ``lymph.models.Unilateral`` instance (without data
        or parameters set).
    """
    model = lymph.models.Unilateral.trinary(
        graph_dict=GRAPH_DICT,
        max_time=MAX_TIME,
    )

    # (name, specificity, sensitivity, kind), registered in this order
    modality_table = (
        ("path", 1, 1, "pathological"),
        ("dc", 0.94, 1, "clinical"),
    )
    for mod_name, specificity, sensitivity, mod_kind in modality_table:
        model.set_modality(
            name=mod_name,
            spec=specificity,
            sens=sensitivity,
            kind=mod_kind,
        )

    # Early T-stage: fixed prior; late T-stage: prior with a free parameter.
    early_prior = binom_pmf(np.arange(MAX_TIME + 1), MAX_TIME, 0.3)
    model.set_distribution(t_stage="early", distribution=early_prior)
    model.set_distribution(t_stage="late", distribution=late_binomial)

    return model
def int_or_none(value: Any) -> int | None: | |
"""Convert string to int or None.""" | |
try: | |
value = int(value) | |
except ValueError: | |
return None | |
return None if value <= 0 else value | |
def create_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` whose description is the module
        docstring. It has two integer options (``--num-patients``,
        ``--max-steps``) and two boolean switches (``--use-global``,
        ``--use-pools``) that are off unless given.
    """
    cli = argparse.ArgumentParser(description=__doc__)

    # Integer-valued options: (flag, default, help text)
    integer_options = (
        ("--num-patients", 50, "Number of patients to generate"),
        ("--max-steps", 10000, "Number of steps to run"),
    )
    for flag, fallback, help_text in integer_options:
        cli.add_argument(flag, type=int, default=fallback, help=help_text)

    # Boolean switches: (flag, help text)
    switches = (
        (
            "--use-global",
            "Use global model. NOT using this with `multiprocess` "
            "should make code brutally slow",
        ),
        (
            "--use-pools",
            "Use multiprocessing. Should make everything faster, "
            "as long as you `--use-global`.",
        ),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", default=False, help=help_text)

    return cli
# Built once at import time so that `global_log_prob_fn` refers to the model
# through a module-level name instead of closing over a local model (the
# `--use-global` help text notes that the closure variant is very slow with
# `multiprocess`). Patient data is loaded into this object inside `main()`.
# NOTE(review): worker processes started with a "spawn" start method would
# re-import this module without running `main()`, so their copy may lack the
# patient data. Confirm which start method `multiprocess` uses here.
GLOBAL_MODEL = setup_trinary_model()


def global_log_prob_fn(theta: ParamsType) -> float:
    """Return the log-likelihood of ``theta`` under ``GLOBAL_MODEL``.

    Args:
        theta: Parameter values forwarded as ``given_params`` to the model's
            ``likelihood`` method.

    Returns:
        The model's log-likelihood for the given parameters.
    """
    log_likelihood = GLOBAL_MODEL.likelihood(given_params=theta, log=True)
    return log_likelihood
def main() -> None:
    """Run the sampling experiment and report how well it recovers ``PARAMS``.

    Steps:

    1. Generate synthetic patients from a local trinary model.
    2. Load them into both the local model and ``GLOBAL_MODEL``.
    3. Sample the parameters with ``emcee``. The CLI flags choose the global
       or the local (closure-based) log-probability function, and whether to
       run through a ``multiprocess`` pool.
    4. Store the chain in ``trinary.hdf5``.
    5. Print the acceptance fraction and each parameter's deviation of the
       best sample from its true value.
    """
    args = create_parser().parse_args()

    # Set up local model and generate synthetic data.
    # NOTE(review): `[1, 0]` presumably assigns all patients to the first
    # T-stage; confirm against lymph's `draw_patients` signature.
    local_model = setup_trinary_model()
    local_model.set_params(**PARAMS)
    generated = local_model.draw_patients(args.num_patients, [1, 0], seed=13)

    # Load data into both the local and the global model. `GLOBAL_MODEL` is
    # only mutated through its methods (never rebound), so no `global`
    # statement is required.
    local_model.load_patient_data(generated, side="ipsi", mapping=lambda x: x)
    GLOBAL_MODEL.load_patient_data(generated, side="ipsi", mapping=lambda x: x)

    # Delete the HDF5 file if it exists, then create a fresh backend in it
    h5file = Path("trinary.hdf5")
    h5file.unlink(missing_ok=True)
    backend = emcee.backends.HDFBackend(filename=str(h5file), name="artificial")

    # Sampler settings: 16 walkers per dimension, each starting uniformly in [0, 1)
    ndim = len(PARAMS)
    nwalkers = 16 * ndim
    initial_state = np.random.uniform(size=(nwalkers, ndim))

    def local_log_prob_fn(theta: ParamsType) -> float:
        """Log-likelihood of ``theta`` under the closed-over local model."""
        return local_model.likelihood(given_params=theta, log=True)

    sampler_kwargs = {
        "nwalkers": nwalkers,
        "ndim": ndim,
        # use global or local log prob function
        "log_prob_fn": global_log_prob_fn if args.use_global else local_log_prob_fn,
        "backend": backend,
        "parameter_names": list(PARAMS.keys()),
    }
    run_kwargs = {
        "initial_state": initial_state,
        "nsteps": args.max_steps,
        "progress": True,
    }

    # run sampling with or without multiprocessing
    if args.use_pools:
        with Pool() as pool:
            sampler = emcee.EnsembleSampler(pool=pool, **sampler_kwargs)
            sampler.run_mcmc(**run_kwargs)
    else:
        sampler = emcee.EnsembleSampler(**sampler_kwargs)
        sampler.run_mcmc(**run_kwargs)

    # use best sample (highest log-probability) and compare it to true values
    samples = sampler.get_chain(flat=True)
    best_idx = np.argmax(sampler.get_log_prob(flat=True))
    params_dict = dict(zip(PARAMS.keys(), samples[best_idx], strict=True))
    local_model.set_params(**params_dict)

    print("Mean accept. frac.:", np.mean(sampler.acceptance_fraction))
    # Absolute deviation of the best sample from the truth, formatted with
    # `.1%` (i.e. in percentage points for probability-valued parameters).
    for name, target in PARAMS.items():
        print(f"{name}: {params_dict[name] - target:.1%}")


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This file was autogenerated by uv via the following command: | |
# uv pip compile - | |
lymph-model @ git+https://github.com/rmnldwg/lymph@67ecfd87d5a9ba483c9c8092e20adc4bfecba25e | |
# via -r - | |
cachetools==5.3.3 | |
# via lymph-model | |
dill==0.3.8 | |
# via multiprocess | |
emcee==3.1.6 | |
h5py==3.11.0 | |
multiprocess==0.70.16 | |
numpy==2.0.0 | |
# via | |
# emcee | |
# h5py | |
# lymph-model | |
# pandas | |
# scipy | |
pandas==2.2.2 | |
# via lymph-model | |
python-dateutil==2.9.0.post0 | |
# via pandas | |
pytz==2024.1 | |
# via pandas | |
scipy==1.14.0 | |
six==1.16.0 | |
# via python-dateutil | |
tqdm==4.66.4 | |
tzdata==2024.1 | |
# via pandas |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment