Skip to content

Instantly share code, notes, and snippets.

@rmnldwg
Created June 27, 2024 14:32
Show Gist options
  • Save rmnldwg/ea790ac9fa469a6cd51613c94aa005a9 to your computer and use it in GitHub Desktop.
Save rmnldwg/ea790ac9fa469a6cd51613c94aa005a9 to your computer and use it in GitHub Desktop.
Performance testing script to debug some weird issues that sometimes occur with the `lymph` package.
"""Test different setting's impact on model's performance."""
import argparse
from pathlib import Path
from typing import Any

import emcee
import numpy as np
from multiprocess import Pool
from scipy.special import comb, factorial

import lymph
import lymph.modalities
from lymph.types import ParamsType
# Passed to the model as `max_time`; it is also the number of trials of the
# early T-stage binomial, whose support is `np.arange(MAX_TIME + 1)`.
MAX_TIME = 15

# Parameter values the synthetic patients are drawn with. The same values
# serve as targets when the sampled parameter means are reported, and the key
# order defines the order of the sampler's dimensions.
PARAMS = {
    "growth": 0.5,
    "II_growth": 0.7,
    "TtoII_spread": 0.24,
    "TtoIII_spread": 0.03,
    "TtoIV_spread": 0.2,
    "IItoIII_spread": 0.18,
    "IIItoIV_spread": 0.18,
    "late_p": 0.6,
}

# Handed to `lymph.models.Unilateral.trinary` as `graph_dict`.
# Keys are (kind, name) tuples, values are lists of node names -- presumably
# the nodes the key node spreads to (cf. the `*_spread` names in PARAMS);
# confirm against the `lymph` documentation.
GRAPH_DICT = {
    ("tumor", "T"): ["II", "III", "IV"],
    ("lnl", "II"): ["III"],
    ("lnl", "III"): ["IV"],
    ("lnl", "IV"): [],
}
def binom_pmf(k: np.ndarray, n: int, p: float) -> np.ndarray:
    """Evaluate the binomial probability mass function.

    Args:
        k: Number(s) of successes at which the PMF is evaluated.
        n: Total number of trials.
        p: Success probability of a single trial.

    Returns:
        ``C(n, k) * p**k * (1 - p)**(n - k)`` for every entry of ``k``.
        Entries with ``k > n`` or ``k < 0`` get probability 0.

    Raises:
        ValueError: If ``p`` is not inside the closed interval [0, 1]
            (NaN is rejected as well).
    """
    # The negated chained comparison also catches NaN, which would pass a
    # plain `p > 1.0 or p < 0.0` test unnoticed.
    if not 0.0 <= p <= 1.0:
        raise ValueError("Binomial prob must be btw. 0 and 1")
    q = 1.0 - p
    # Building the coefficient from three separate factorials overflows to
    # `inf / inf = nan` as soon as n > 170; `comb` computes it directly.
    binom_coeff = comb(n, k)
    return binom_coeff * p**k * q ** (n - k)
def late_binomial(support: np.ndarray, p: float = 0.5) -> np.ndarray:
    """Binomial PMF over ``support`` with success probability ``p``.

    The number of trials is read from the last entry of ``support``.
    """
    num_trials = support[-1]
    return binom_pmf(k=support, n=num_trials, p=p)
def setup_trinary_model():
    """Create the trinary unilateral `lymph` model used by this script.

    The model is built from ``GRAPH_DICT`` with ``max_time=MAX_TIME``. It gets
    the two modalities "path" and "dc", a fixed binomial distribution for the
    "early" T-stage, and the callable `late_binomial` for the "late" T-stage.
    """
    # Fixed distribution over the time steps 0..MAX_TIME for early T-stage
    time_steps = np.arange(MAX_TIME + 1)
    early_prior = binom_pmf(time_steps, MAX_TIME, 0.3)

    trinary = lymph.models.Unilateral.trinary(
        graph_dict=GRAPH_DICT,
        max_time=MAX_TIME,
    )
    trinary.set_modality(name="path", spec=1, sens=1, kind="pathological")
    trinary.set_modality(name="dc", spec=0.94, sens=1, kind="clinical")

    trinary.set_distribution(t_stage="early", distribution=early_prior)
    # A callable is passed here instead of an array; presumably this is what
    # exposes its `p` as the model parameter "late_p" -- confirm in the
    # `lymph` documentation.
    trinary.set_distribution(t_stage="late", distribution=late_binomial)
    return trinary
def int_or_none(value: Any) -> int | None:
"""Convert string to int or None."""
try:
value = int(value)
except ValueError:
return None
return None if value <= 0 else value
def create_parser():
    """Build the command line parser of this script.

    The parser knows two integer options (``--num-patients``, ``--max-steps``)
    and two boolean switches (``--use-global``, ``--use-pools``) that are off
    by default. The module docstring is used as the program description.
    """
    cli = argparse.ArgumentParser(description=__doc__)

    # Integer-valued options as (flag, default, help text)
    int_options = (
        ("--num-patients", 50, "Number of patients to generate"),
        ("--max-steps", 10000, "Number of steps to run"),
    )
    for flag, default, help_text in int_options:
        cli.add_argument(flag, type=int, default=default, help=help_text)

    # Boolean switches as (flag, help text)
    switches = (
        (
            "--use-global",
            "Use global model. NOT using this with `multiprocess` "
            "should make code brutally slow",
        ),
        (
            "--use-pools",
            "Use multiprocessing. Should make everything faster, "
            "as long as you `--use-global`.",
        ),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", default=False, help=help_text)

    return cli
# Module-level model instance, created when the module is imported.
# `global_log_prob_fn` evaluates its likelihood and `main` loads the generated
# patient data into it.
GLOBAL_MODEL = setup_trinary_model()
def global_log_prob_fn(theta: ParamsType) -> float:
    """Return the log-likelihood of ``theta`` under the module-level model."""
    log_likelihood = GLOBAL_MODEL.likelihood(given_params=theta, log=True)
    return log_likelihood
def main() -> None:
    """Sample the trinary model on synthetic data and print a short report.

    The function parses the command line, draws ``--num-patients`` synthetic
    patients from a model set to ``PARAMS``, and loads them into both a local
    model and ``GLOBAL_MODEL``. It then runs an `emcee` ensemble sampler for
    ``--max-steps`` steps, with a `multiprocess` pool if ``--use-pools`` is
    given. ``--use-global`` selects `global_log_prob_fn` instead of a closure
    over the local model as the log-probability function.

    Finally, the mean acceptance fraction and, per parameter, the difference
    between the mean over all samples and the value in ``PARAMS`` are printed.
    """
    args = create_parser().parse_args()

    # The synthetic data is drawn from a local model set to the true values
    local_model = setup_trinary_model()
    local_model.set_params(**PARAMS)
    patients = local_model.draw_patients(args.num_patients, [1, 0], seed=13)

    # Local and module-level model receive the same data. The global model is
    # only accessed through a method call, so no `global` statement is needed.
    for model in (local_model, GLOBAL_MODEL):
        model.load_patient_data(patients, side="ipsi", mapping=lambda x: x)

    # Remove a leftover HDF5 file from an earlier run before creating the backend
    h5_path = Path("trinary.hdf5")
    if h5_path.exists():
        h5_path.unlink()
    backend = emcee.backends.HDFBackend(filename="trinary.hdf5", name="artificial")

    # One dimension per parameter, 16 walkers per dimension
    ndim = len(PARAMS)
    nwalkers = 16 * ndim
    start_positions = np.random.uniform(size=(nwalkers, ndim))

    def local_log_prob_fn(theta: ParamsType) -> float:
        return local_model.likelihood(given_params=theta, log=True)

    # Either the module-level function or the closure over `local_model`
    if args.use_global:
        chosen_log_prob_fn = global_log_prob_fn
    else:
        chosen_log_prob_fn = local_log_prob_fn

    sampler_kwargs = {
        "nwalkers": nwalkers,
        "ndim": ndim,
        "log_prob_fn": chosen_log_prob_fn,
        "backend": backend,
        "parameter_names": list(PARAMS.keys()),
    }
    run_kwargs = {
        "initial_state": start_positions,
        "nsteps": args.max_steps,
        "progress": True,
    }

    # The sampling itself must happen while the pool is still open
    if not args.use_pools:
        sampler = emcee.EnsembleSampler(**sampler_kwargs)
        sampler.run_mcmc(**run_kwargs)
    else:
        with Pool() as pool:
            sampler = emcee.EnsembleSampler(pool=pool, **sampler_kwargs)
            sampler.run_mcmc(**run_kwargs)

    # Put the sample with the highest log-probability into the local model
    flat_chain = sampler.get_chain(flat=True)
    flat_log_probs = sampler.get_log_prob(flat=True)
    best_sample = flat_chain[np.argmax(flat_log_probs)]
    best_params = dict(zip(PARAMS.keys(), best_sample, strict=True))
    local_model.set_params(**best_params)

    print("Mean accept. frac.:", np.mean(sampler.acceptance_fraction))
    # The report below uses the mean over all samples, not the best sample
    sample_means = flat_chain.mean(axis=0)
    for (name, target), result in zip(PARAMS.items(), sample_means, strict=True):
        print(f"{name}: {result - target:.1%}")


if __name__ == "__main__":
    main()
# This file was autogenerated by uv via the following command:
# uv pip compile -
lymph-model @ git+https://github.com/rmnldwg/lymph@67ecfd87d5a9ba483c9c8092e20adc4bfecba25e
# via -r -
cachetools==5.3.3
# via lymph-model
dill==0.3.8
# via multiprocess
emcee==3.1.6
h5py==3.11.0
multiprocess==0.70.16
numpy==2.0.0
# via
# emcee
# h5py
# lymph-model
# pandas
# scipy
pandas==2.2.2
# via lymph-model
python-dateutil==2.9.0.post0
# via pandas
pytz==2024.1
# via pandas
scipy==1.14.0
six==1.16.0
# via python-dateutil
tqdm==4.66.4
tzdata==2024.1
# via pandas
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment