Skip to content

Instantly share code, notes, and snippets.

@jeetsukumaran
Last active August 29, 2015 13:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jeetsukumaran/9217140 to your computer and use it in GitHub Desktop.
Save jeetsukumaran/9217140 to your computer and use it in GitHub Desktop.
Generate configuration files for Vaughn's MASTER to generate viral phylogenies under different phylodynamic regimes
#! /usr/bin/env python
__version__ = "1.0.0"
from time import strftime
import random
import sys
import os
import argparse
# ---------------------------------------------------------------------------
# XML fragments. All of these are `str.format` templates; the fragments are
# rendered first and then substituted into MASTER_XML_TEMPLATE.
# ---------------------------------------------------------------------------

# Optional `simulationTime` attribute; fills the {simulation_time} slot of
# MASTER_XML_TEMPLATE (an empty string is substituted when no run time is set).
SIMULATION_TIME_XML_TEMPLATE = """ simulationTime='{simulation_time}'"""
# Optional `seed` attribute; fills the {random_seed} slot of
# MASTER_XML_TEMPLATE. Note the field is keyword-named: `random_seed`.
SIMULATION_RANDOM_SEED_XML_TEMPLATE = """ seed='{random_seed}'"""
# Optional end condition on population 'N'; fills the
# {population_end_condition} slot of MASTER_XML_TEMPLATE.
POPULATION_END_CONDITION_XML_TEMPLATE = """\
<populationEndCondition spec='PopulationEndCondition'
threshold="{population_end_condition_threshold}"
exceedCondition="false"
isRejection="false">
<population spec='Population' idref='N'/>
</populationEndCondition>"""
# A single reaction element. {reaction_name} is a complete, pre-rendered
# `reactionName=...` attribute (or an empty string), not just the bare name.
REACTION_XML_TEMPLATE = """\
<reaction spec='InheritanceReaction' {reaction_name} rate="{reaction_rate}">
{reaction_description}
</reaction> """
# A named group wrapping already-rendered reaction elements ({reactions}).
REACTION_GROUP_XML_TEMPLATE = """\
<reactionGroup spec='InheritanceReactionGroup' reactionGroupName="{reaction_group_name}">
{reactions}
</reactionGroup>"""
# The complete configuration document. {output_prefix} appears three times
# (Newick, Nexus and JSON outputs); callers in this file substitute the
# literal text "{output_prefix}" on the first pass so that it can be filled
# in per replicate by a second `format` call.
MASTER_XML_TEMPLATE = """\
{remarks}
<beast version='2.0' namespace='master.beast:beast.core.parameter'>
<run spec='InheritanceTrajectory'
{simulation_time}
samplePopulationSizes="true"
sampleAtNodesOnly="true"
{random_seed}
verbosity="1">
<model spec='InheritanceModel' id='model'>
<population spec='Population' id='N' populationName='N'/>
<populationType spec='PopulationType' typeName='I' dim='{num_infected_host_types}' id='I'/>
<populationType spec='PopulationType' typeName='U' dim='{num_uninfected_host_types}' id='U'/>
{reaction_directives}
</model>
<initialState spec='InitState'>
<populationSize spec='PopulationSize' population='@N' size='{initial_naive_population_size}'/>
<lineageSeed spec='Individual' time="0.0">
<population spec='Population' type='@I' location="0"/>
</lineageSeed>
</initialState>
{population_end_condition}
<output spec='NewickOutput' fileName='{output_prefix}.newick'/>
<output spec='NexusOutput' fileName='{output_prefix}.nexus'/>
<output spec='JsonOutput' fileName='{output_prefix}.json'/>
</run>
</beast>
"""
# Shell job script (the `#$` lines are scheduler directives embedded in the
# script text). The Java heap flags use the same 3072M figure as `mem_free`.
JOB_TEMPLATE = """\
#! /bin/bash
#$ -cwd
#$ -V
#$ -S /bin/bash
#$ -l mem_free=3072M
java -Xms3072m -Xmx3072m -jar {jar_path} {xml_path}"""
# The MASTER jar is looked up relative to this script:
# <script dir>/../MASTER-1.0/MASTER-1.0.jar (made absolute).
script_dir = os.path.abspath(os.path.dirname(__file__))
jar_dir = os.path.join(script_dir, os.pardir, "MASTER-1.0")
jar_path = os.path.abspath(os.path.join(jar_dir, "MASTER-1.0.jar"))
# Bake the jar path in now, but re-emit the "{xml_path}" placeholder so the
# template can be formatted once more per generated XML file.
JOB_TEMPLATE = JOB_TEMPLATE.format(jar_path=jar_path, xml_path="{xml_path}")
def number_to_base_str(num, base, numerals="0123456789abcdefghijklmnopqrstuvwxyz"):
    """
    Return the representation of the non-negative integer `num` in `base`.

    Digits are drawn from `numerals` (`numerals[0]` is the zero digit), most
    significant digit first, with no leading zeros; zero itself is rendered
    as `numerals[0]`.

    Raises:
        ValueError: if `num` is negative, or if `base` is smaller than 2 or
            larger than the number of available numerals. (The previous
            recursive implementation never terminated for negative `num` or
            `base` < 2, and failed with an IndexError for oversized bases.)
    """
    if num < 0:
        raise ValueError("cannot convert negative number: {}".format(num))
    if base < 2 or base > len(numerals):
        raise ValueError(
            "base must be between 2 and {}, but found: {}".format(len(numerals), base))
    if num == 0:
        return numerals[0]
    digits = []
    while num:
        # Peel off the least significant digit on each pass.
        num, remainder = divmod(num, base)
        digits.append(numerals[remainder])
    return "".join(reversed(digits))
def bitmask_to_str(b, width):
    """
    Render the integer `b` in binary, left-padded with '0' characters to at
    least `width` characters (longer values are not truncated).
    """
    binary_digits = format(b, "b")
    return binary_digits.rjust(width, "0")
class StrainTransitionReaction(object):
    """
    A single-reactant reaction that turns one host type (`progenitor`) into
    another (`product`).

    The realized rate is the caller-supplied normalized base rate scaled by
    `adjustment_factor`; it is stored on `realized_reaction_rate` each time
    it is computed. Subclasses customize `reaction_name` and
    `reaction_description_template`.
    """

    def __init__(self, progenitor, product, adjustment_factor):
        self.progenitor = progenitor
        self.product = product
        self.adjustment_factor = adjustment_factor
        # Populated by `calculate_realized_reaction_rate`.
        self.realized_reaction_rate = None
        # Filled with (progenitor label, product label).
        self.reaction_description_template = "{}:1 -> {}:1"
        # When left as None, no reactionName attribute is emitted.
        self.reaction_name = None

    def calculate_realized_reaction_rate(self, normalized_base_rate):
        """Compute, remember and return base rate x adjustment factor."""
        rate = normalized_base_rate * self.adjustment_factor
        self.realized_reaction_rate = rate
        return rate

    def xml(self, normalized_base_rate):
        """
        Return this reaction rendered through REACTION_XML_TEMPLATE.

        Side effect: updates `realized_reaction_rate`.
        """
        rate = self.calculate_realized_reaction_rate(normalized_base_rate)
        description = self.reaction_description_template.format(
            self.progenitor.label, self.product.label)
        name_attribute = ""
        if self.reaction_name:
            name_attribute = "reactionName = '{}'".format(self.reaction_name)
        return REACTION_XML_TEMPLATE.format(
            reaction_name=name_attribute,
            reaction_rate=rate,
            reaction_description=description)
class NewStrainReaction(StrainTransitionReaction):
    """
    Strain-transition reaction named "Mutation_<from>_to_<to>", where the
    two parts are the infection profiles of the progenitor and the product.
    """

    def __init__(self, progenitor, product, adjustment_factor):
        super(NewStrainReaction, self).__init__(
            progenitor=progenitor,
            product=product,
            adjustment_factor=adjustment_factor)
        name_parts = (progenitor.infection_profile, product.infection_profile)
        self.reaction_name = "Mutation_{}_to_{}".format(*name_parts)
class RecoveryReaction(StrainTransitionReaction):
    """
    Strain-transition reaction named "Recovery_<from>_to_<to>" (infection
    profiles of progenitor and product). Unlike the base class, its
    description template carries no ":1" labels on either side.
    """

    def __init__(self, progenitor, product, adjustment_factor):
        super(RecoveryReaction, self).__init__(
            progenitor=progenitor,
            product=product,
            adjustment_factor=adjustment_factor)
        name_parts = (progenitor.infection_profile, product.infection_profile)
        self.reaction_name = "Recovery_{}_to_{}".format(*name_parts)
        self.reaction_description_template = "{} -> {}"
class MortalityReaction(StrainTransitionReaction):
    """
    Strain-transition reaction that keeps every base-class default: no
    reaction name and the "{}:1 -> {}:1" description template.
    """

    def __init__(self, progenitor, product, adjustment_factor):
        super(MortalityReaction, self).__init__(
            progenitor=progenitor,
            product=product,
            adjustment_factor=adjustment_factor)
class InfectionReaction(object):
    """
    Two-reactant reaction in which a `source` host type turns a `recipient`
    host type into `product`; the source appears unchanged among the
    products.

    The realized rate is the normalized base infection rate multiplied by
    two factors: `multiple_potential_strain_in_source_adjustment_factor` and
    `cross_susceptibility_factor`.
    """

    def __init__(self,
                 source,
                 recipient,
                 product,
                 multiple_potential_strain_in_source_adjustment_factor,
                 cross_susceptibility_factor):
        self.source = source
        self.recipient = recipient
        self.product = product
        self.multiple_potential_strain_in_source_adjustment_factor = (
            multiple_potential_strain_in_source_adjustment_factor)
        self.cross_susceptibility_factor = cross_susceptibility_factor
        # Populated by `calculate_realized_reaction_rate`.
        self.realized_reaction_rate = None

    def calculate_realized_reaction_rate(self, normalized_base_infection_rate):
        """Compute, remember and return the scaled infection rate."""
        rate = normalized_base_infection_rate
        rate = rate * self.multiple_potential_strain_in_source_adjustment_factor
        rate = rate * self.cross_susceptibility_factor
        self.realized_reaction_rate = rate
        return rate

    def xml(self, normalized_base_infection_rate):
        """
        Return this reaction rendered through REACTION_XML_TEMPLATE.

        Side effect: updates `realized_reaction_rate`.
        """
        rate = self.calculate_realized_reaction_rate(normalized_base_infection_rate)
        source_label = self.source.label
        description = "{}:1 + {}:2 -> {}:1 + {}:1".format(
            source_label,
            self.recipient.label,
            source_label,
            self.product.label)
        name_attribute = "reactionName='Infection_{}_{}_{}'".format(
            self.source.infection_profile,
            self.recipient.infection_profile,
            self.product.infection_profile)
        return REACTION_XML_TEMPLATE.format(
            reaction_name=name_attribute,
            reaction_rate=rate,
            reaction_description=description)
class HostType(object):
    """
    One class of host, described by two per-strain '0'/'1' strings:
    `immunity_profile` and `infection_profile`.

    `label` is derived once, at construction time:
      - "I[<subpop_id>]" if the infection profile contains a "1" (infected);
      - "U[<subpop_id>]" if only the immunity profile contains a "1"
        (uninfected);
      - "N" otherwise (naive, not exposed).
    """

    def __init__(self, immunity_profile, infection_profile, subpop_id):
        self.immunity_profile = immunity_profile
        self.infection_profile = infection_profile
        self.subpop_id = subpop_id
        self.label = self._compose_label()

    def _compose_label(self):
        """Return the population label for this host type."""
        if "1" in self.infection_profile:
            return "I[{}]".format(self.subpop_id)  # infected
        if "1" in self.immunity_profile:
            return "U[{}]".format(self.subpop_id)  # uninfected
        return "N"  # naive (not exposed)
class Epidemiology(object):
    """
    Multi-strain epidemiological model that renders itself as a MASTER XML
    configuration document.

    Host types are enumerated as (immunity profile, infection profile) pairs
    of '0'/'1' strings holding one character per strain. Infection,
    new-strain ("mutation") and recovery reactions between those host types
    are rendered as reaction groups inside MASTER_XML_TEMPLATE.

    Settings are plain attributes, assigned after construction.

    Note: `global_mortality_rate` is only reported in the generated remarks;
    this class does not generate any mortality reactions.
    """

    num_exposure_states = 3

    def __init__(self, max_strains):
        self.max_strains = max_strains
        self.global_infection_rate = 0.2
        self.global_recovery_rate = 0.2
        self.global_mortality_rate = 0.2
        self.strain_mutation_rate = 0.2
        # Legacy attribute names, kept for backward compatibility. Every
        # method of this class reads the `cross_susceptibility_*` names, so
        # those are initialized (to the same defaults) as well; previously
        # they only existed if the caller assigned them, and XML generation
        # failed with an AttributeError otherwise.
        self.cross_immunity_factor = 0.2
        self.cross_immunity_model = "distance"
        self.cross_susceptibility_factor = self.cross_immunity_factor
        self.cross_susceptibility_model = self.cross_immunity_model
        self.initial_naive_population_size = 1000
        self.simulation_run_time = 100
        self.naive_population_proportion_termination_threshold = None
        self.rng = random.Random()
        # When true, a seed drawn from `self.rng` is written into the XML.
        self.seed_simulator_rng = False
        self.num_uninfected_host_types = None
        self.num_infected_host_types = None
        self.host_types = []
        self.profile_host_map = {}

    def build_host_types(self):
        """
        (Re)build `host_types`, `profile_host_map` and the infected and
        uninfected host-type counters.

        The naive host (no immunity, no infection) always comes first. Every
        other host type pairs a non-zero immunity profile with an infection
        profile whose strains are all present in the immunity profile.
        Infected and uninfected host types are numbered independently; the
        naive host is not entered into `profile_host_map`.
        """
        self.host_types = []
        self.profile_host_map = {}
        self.num_uninfected_host_types = 0
        self.num_infected_host_types = 0
        num_profiles = 2 ** self.max_strains
        for immunity_bitmask in range(num_profiles):
            immunity_profile = self.bitmask_to_str(immunity_bitmask)
            if immunity_bitmask == 0:
                self.host_types.append(HostType(
                    immunity_profile=immunity_profile,
                    infection_profile=immunity_profile,
                    subpop_id=None))
                continue
            for infection_bitmask in range(num_profiles):
                infection_profile = self.bitmask_to_str(infection_bitmask)
                # Skip infections by strains absent from the immunity profile.
                if any(immune == "0" and infected == "1"
                       for immune, infected in zip(immunity_profile, infection_profile)):
                    continue
                if "1" in infection_profile:
                    subpop_id = self.num_infected_host_types
                    self.num_infected_host_types += 1
                else:
                    subpop_id = self.num_uninfected_host_types
                    self.num_uninfected_host_types += 1
                host = HostType(
                    immunity_profile=immunity_profile,
                    infection_profile=infection_profile,
                    subpop_id=subpop_id)
                self.host_types.append(host)
                self.profile_host_map[(immunity_profile, infection_profile)] = host

    def generate_xml_template(self):
        """
        Return the MASTER XML document with every setting filled in except
        the literal "{output_prefix}" placeholder, which is left in place for
        a later `format(output_prefix=...)` call.

        Host types are built on demand if `build_host_types` has not been
        called. A reaction group is only emitted when at least one reactant
        (or contact pair) has reactions of that kind; previously the
        infection group was unconditional and divided by zero when there
        were no infection contacts.
        """
        if not self.host_types:
            self.build_host_types()
        remarks = self._settings_remarks()
        infections, mutations, recoveries = self._collect_reactions()
        reaction_groups = []
        for group_name, rate_label, (reactions, num_reactants), requested_rate in (
                ("Infections", "infection", infections, self.global_infection_rate),
                ("Mutations", "mutation", mutations, self.strain_mutation_rate),
                ("Recoveries", "recovery", recoveries, self.global_recovery_rate)):
            if not num_reactants:
                continue
            reaction_groups.append(self._reaction_group_xml(
                group_name=group_name,
                rate_label=rate_label,
                reactions=reactions,
                num_reactants=num_reactants,
                requested_rate=requested_rate,
                remarks=remarks))
        reactions_xml = "\n".join(reaction_groups)
        if self.simulation_run_time is not None:
            simulation_time_xml = SIMULATION_TIME_XML_TEMPLATE.format(
                simulation_time=self.simulation_run_time)
        else:
            simulation_time_xml = ""
        threshold = self.naive_population_proportion_termination_threshold
        if threshold is not None:
            population_end_condition_xml = POPULATION_END_CONDITION_XML_TEMPLATE.format(
                population_end_condition_threshold=self.initial_naive_population_size * threshold)
        else:
            population_end_condition_xml = ""
        if self.seed_simulator_rng:
            # The template field is keyword-named, and `sys.maxsize` exists
            # on both Python 2 and 3 (the former `randit`/`sys.maxint` call
            # could never succeed).
            random_seed_xml = SIMULATION_RANDOM_SEED_XML_TEMPLATE.format(
                random_seed=self.rng.randint(0, sys.maxsize))
        else:
            random_seed_xml = ""
        remarks_xml = "<!--\n\n" + "\n".join(remarks) + "\n-->\n"
        return MASTER_XML_TEMPLATE.format(
            remarks=remarks_xml,
            simulation_time=simulation_time_xml,
            random_seed=random_seed_xml,
            num_infected_host_types=self.num_infected_host_types,
            num_uninfected_host_types=self.num_uninfected_host_types,
            reaction_directives=reactions_xml,
            # One less than the configured size; the template separately
            # declares a single seed individual of type I.
            initial_naive_population_size=self.initial_naive_population_size - 1,
            population_end_condition=population_end_condition_xml,
            # Re-emit the placeholder so it survives this format pass.
            output_prefix="{output_prefix}")

    def _settings_remarks(self):
        """Return the header lines (a list of strings) describing the settings."""
        remarks = [
            "MASTER configuration file for phylodynamics simulations",
            "Generated by PhylodynaMaster v{} on {}".format(__version__, strftime("%Y-%m-%d %H:%M:%S")),
            "",
            "Settings:",
            " * Maximum number of strains : {}".format(self.max_strains),
            " * Global base infection rate : {}".format(self.global_infection_rate),
            " * Global base recovery rate : {}".format(self.global_recovery_rate),
            " * Global mortality rate : {}".format(self.global_mortality_rate),
            " * Strain mutation rate : {}".format(self.strain_mutation_rate),
            " * Cross-susceptibility model : '{}'".format(self.cross_susceptibility_model),
            " * Cross-susceptibility factor : {}".format(self.cross_susceptibility_factor),
            " * Initial susceptible population size : {}".format(self.initial_naive_population_size),
        ]
        if self.simulation_run_time is None:
            remarks.append(" * Simulation maximum run time : not constrained")
        else:
            remarks.append(" * Simulation maximum run time : {}".format(self.simulation_run_time))
        threshold = self.naive_population_proportion_termination_threshold
        if threshold is None:
            remarks.append(" * Naive population proportion termination threshold : not set")
        else:
            remarks.append(" * Naive population proportion termination threshold : {}".format(threshold))
        remarks.append("")
        return remarks

    def _collect_reactions(self):
        """
        Return three `(reactions, count)` pairs: infections, mutations and
        recoveries. `count` is the number of reactants (for infections: of
        ordered host pairs) that produced at least one reaction.
        """
        infection_reactions = []
        mutation_reactions = []
        recovery_reactions = []
        num_contacts_with_infection_reaction = 0
        num_reactants_with_mutations = 0
        num_reactants_with_recovery = 0
        for hidx1, host1 in enumerate(self.host_types):
            if hidx1 > 0:  # index 0 is the naive host
                reactions = self.calc_mutation_reactions(host1)
                if reactions:
                    num_reactants_with_mutations += 1
                    mutation_reactions.extend(reactions)
                reactions = self.calc_recovery_reactions(host1)
                if reactions:
                    num_reactants_with_recovery += 1
                    recovery_reactions.extend(reactions)
            # Every unordered pair of distinct host types, tried in both
            # source/recipient orientations.
            for host2 in self.host_types[hidx1 + 1:]:
                for source, recipient in ((host1, host2), (host2, host1)):
                    reactions = self.calc_infection_reactions(source, recipient)
                    if reactions:
                        num_contacts_with_infection_reaction += 1
                        infection_reactions.extend(reactions)
        return (
            (infection_reactions, num_contacts_with_infection_reaction),
            (mutation_reactions, num_reactants_with_mutations),
            (recovery_reactions, num_reactants_with_recovery),
        )

    def _reaction_group_xml(self, group_name, rate_label, reactions, num_reactants, requested_rate, remarks):
        """
        Render `reactions` as one reaction group and append a note comparing
        the realized and requested global rates to `remarks`.

        The requested rate is spread evenly over `num_reactants` (which must
        be non-zero) to obtain the base rate passed to each reaction.
        """
        normalized_base_rate = 1.0 / num_reactants * requested_rate
        rendered = []
        realized_rate = 0.0
        for reaction in reactions:
            rendered.append(reaction.xml(normalized_base_rate))
            realized_rate += reaction.realized_reaction_rate
        remarks.append(
            "Note: Realized global {0} rate is {1} (requested global {0} rate: {2})".format(
                rate_label, realized_rate, requested_rate))
        return REACTION_GROUP_XML_TEMPLATE.format(
            reaction_group_name=group_name,
            reactions="\n".join(rendered))

    def generate_xml(self, simulation_output_prefix="out"):
        """
        Return the complete XML document; output file names are based on the
        base name of `simulation_output_prefix`.
        """
        template = self.generate_xml_template()
        return template.format(output_prefix=os.path.basename(simulation_output_prefix))

    def _host_after_new_infection(self, immunity_profile, infection_profile, strain_idx):
        """
        Return the registered host type obtained by marking `strain_idx` as
        "1" in both the immunity and the infection profile.
        """
        new_infection = list(infection_profile)
        new_infection[strain_idx] = "1"
        new_immunity = list(immunity_profile)
        new_immunity[strain_idx] = "1"
        return self.profile_host_map[("".join(new_immunity), "".join(new_infection))]

    def calc_mutation_reactions(self, host):
        """
        Return one NewStrainReaction per strain that `host` is not currently
        infected with, each weighted by 1 / (number of such strains).

        Returns an empty list if the host carries no infection at all, or is
        already infected with every strain.
        """
        infection_profile = host.infection_profile
        potential_new_strain_idxs = [
            idx for idx, state in enumerate(infection_profile) if state == "0"]
        if "1" not in infection_profile or not potential_new_strain_idxs:
            return []
        base_prob = 1.0 / len(potential_new_strain_idxs)
        return [
            NewStrainReaction(
                progenitor=host,
                product=self._host_after_new_infection(
                    host.immunity_profile, infection_profile, idx),
                adjustment_factor=base_prob)
            for idx in potential_new_strain_idxs]

    def calc_recovery_reactions(self, host):
        """
        Return one RecoveryReaction per strain currently infecting `host`,
        each weighted by 1 / (number of current infections). The product
        keeps the host's immunity profile and clears that one infection.

        Returns an empty list for hosts without a current infection.
        """
        infection_profile = host.infection_profile
        current_strain_idxs = [
            idx for idx, state in enumerate(infection_profile) if state == "1"]
        if not current_strain_idxs:
            return []
        base_prob = 1.0 / len(current_strain_idxs)
        reactions = []
        for idx in current_strain_idxs:
            product_infection_profile = list(infection_profile)
            product_infection_profile[idx] = "0"
            product = self.profile_host_map[
                (host.immunity_profile, "".join(product_infection_profile))]
            # Recoveries used to be built as NewStrainReaction, which
            # mislabelled them "Mutation_*" in the generated XML and left
            # RecoveryReaction unused.
            # NOTE(review): RecoveryReaction also renders "{} -> {}" instead
            # of "{}:1 -> {}:1" -- confirm this is the intended reaction
            # string for recoveries in MASTER.
            reactions.append(RecoveryReaction(
                progenitor=host,
                product=product,
                adjustment_factor=base_prob))
        return reactions

    def calc_infection_reactions(self, source, recipient):
        """
        Return one InfectionReaction for every strain that `source` is
        currently infected with and that is "0" in the immunity profile of
        `recipient`.

        Each reaction is weighted by 1 / (number of such strains) and by the
        recipient's cross-susceptibility factor for that strain. Returns an
        empty list if no strain qualifies.
        """
        potentials = [
            idx for idx, state in enumerate(recipient.immunity_profile)
            if state == "0" and source.infection_profile[idx] == "1"]
        if not potentials:
            return []
        multiple_potential_strain_in_source_adjustment_factor = 1.0 / len(potentials)
        return [
            InfectionReaction(
                source=source,
                recipient=recipient,
                product=self._host_after_new_infection(
                    recipient.immunity_profile, recipient.infection_profile, idx),
                multiple_potential_strain_in_source_adjustment_factor=multiple_potential_strain_in_source_adjustment_factor,
                cross_susceptibility_factor=self.calc_cross_susceptibility_factor(recipient, idx))
            for idx in potentials]

    def get_distance_from_nearest_active_strain(self, profile, idx):
        """
        Return the number of positions between `idx` and the nearest "1" in
        `profile`, where the profile wraps around at both ends and position
        `idx` itself is ignored. Returns 0 if there is no other "1".
        """
        # All other positions, in order of increasing forward distance.
        fwd = profile[idx + 1:] + profile[:idx]
        if "1" not in fwd:
            return 0
        rev = fwd[::-1]
        return min(fwd.find("1"), rev.find("1")) + 1

    def calc_cross_susceptibility_factor(self, recipient, idx):
        """
        Return the infection-rate multiplier for strain `idx` in `recipient`.

        - 0.0 if the recipient's immunity profile is not "0" at `idx`;
        - model "simple": `cross_susceptibility_factor` if the recipient has
          any "1" in its immunity profile, else 1.0;
        - model "distance": `cross_susceptibility_factor` multiplied by the
          distance to the nearest "1" in the immunity profile, or 1.0 if
          there is none.

        Raises:
            ValueError: for any other model name. (Previously this returned
                None, which surfaced later as a TypeError when the rate was
                computed.)
        """
        immunity_profile = recipient.immunity_profile
        if immunity_profile[idx] != "0":
            return 0.0
        model = self.cross_susceptibility_model
        if model == "simple":
            if "1" in immunity_profile:
                return self.cross_susceptibility_factor
            return 1.0
        if model == "distance":
            distance = self.get_distance_from_nearest_active_strain(immunity_profile, idx)
            if not distance:
                return 1.0
            return self.cross_susceptibility_factor * distance
        raise ValueError(
            "Unrecognized cross-susceptibility model: '{}'".format(model))

    # `epi_type`: a number that will be converted to an array of base-3 bits,
    # where each element codes the epidemiological state of the host with
    # respect to a particular virus strain, indicating:
    # - 0: global/susceptible
    # - 1: infected
    # - 2: recovered
    def host_index_to_profile(self, host_index):
        """
        Return `host_index` written in base `num_exposure_states`,
        left-padded with "0" to `max_strains` characters.
        """
        digits = number_to_base_str(host_index, Epidemiology.num_exposure_states)
        return "{0:0>{1}}".format(digits, self.max_strains)

    def host_profile_to_index(self, host_profile):
        """Inverse of `host_index_to_profile`: parse a base-3 profile string."""
        return int(host_profile, Epidemiology.num_exposure_states)

    def bitmask_to_str(self, b):
        """Return `b` in binary, left-padded with "0" to `max_strains` characters."""
        return "{0:0>{1}b}".format(b, self.max_strains)
def main():
    """
    Main CLI handler.

    Parses the command line, configures an `Epidemiology` generator and
    writes one `<output-prefix>.NNN.xml` file per replicate, plus a matching
    `.sge` job script when `--create-job-files` is given.

    Exits with an error message unless at least one of `--run-time` or
    `--termination-threshold` is specified.
    """
    parser = argparse.ArgumentParser()
    epidemiology_options = parser.add_argument_group("Epidemiological System")
    epidemiology_options.add_argument("--max-strains", type=int, default=4,
            help="Maximum number of strains (default=%(default)s)")
    epidemiology_options.add_argument("--strain-mutation-rate", type=float, default=0.2,
            help="Rate of new strain emergence (default=%(default)s)")
    epidemiology_options.add_argument("--infection-rate", type=float, default=0.2,
            help="Global infection rate (across all strains; default=%(default)s)")
    epidemiology_options.add_argument("--recovery-rate", type=float, default=0.2,
            help="Global recovery rate (across all strains; default=%(default)s)")
    epidemiology_options.add_argument("--mortality-rate", type=float, default=0.2,
            help="Global mortality rate (across all strains; default=%(default)s)")
    # Restricting the choices rejects a misspelled model at parse time
    # instead of failing in the middle of XML generation.
    epidemiology_options.add_argument("--cross-susceptibility-model", type=str, default="simple",
            choices=("simple", "distance"),
            help="""\
How previous exposure to other strains affects infection rate when exposed to
a new strain.
'simple': [DEFAULT] infection rate of any new strain is multiplied by
`cross-susceptibility-factor` if already exposed to at least one other strain
'distance': infection rate of any new strain multiplied
`cross-susceptibility-factor` weighted by distance of the new strain
from strains to which recipient has been exposed
(rate = base_rate * `cross-susceptibility-factor` * distance);
""")
    # Not implemented:
    # 'simple-additive': infection rate of any new strain is reduced by
    # `cross-immunity-factor` multiplied by the number of strains to which the
    # recipient has already been exposed (rate = base_rate * (n *
    # `cross_immunity_factor`));
    epidemiology_options.add_argument("--cross-susceptibility-factor",
            type=float, default=1.0,
            help="""\
If < 1.0 then susceptibility to new strains is reduced following exposure
(some degree of cross-immunity); if > 1.0 then susceptibility to new strains
is increased following exposure (enhanced immunity, as with, e.g. dengue);
default=%(default)s""")
    epidemiology_options.add_argument("--initial-naive-population-size", type=int, default=1000,
            help="Initial naive host population size (default=%(default)s)")
    simulation_options = parser.add_argument_group("Simulation Options")
    simulation_options.add_argument("--run-time", type=int, default=None,
            help="Simulation run time (default=%(default)s)")
    simulation_options.add_argument("--termination-threshold", type=float, default=None,
            help="Terminate simulation when naive (unexposed) population proportion falls below this value (default=%(default)s)")
    simulation_options.add_argument("--num-replicates", type=int, default=1,
            help="Number of replicate simulation run files to generate (default=%(default)s)")
    simulation_options.add_argument("--output-prefix", type=str, default="master-phylodynamics",
            help="Output prefix for simulation products (default=%(default)s)")
    simulation_options.add_argument("--create-job-files", action="store_true", default=False,
            help="create job script files")
    args = parser.parse_args()
    # Truthiness check: a value of 0 counts as "not specified", as before.
    if not args.termination_threshold and not args.run_time:
        sys.exit("Need to specify at least one of '--run-time' or '--termination-threshold'")

    generator = Epidemiology(args.max_strains)
    generator.global_infection_rate = args.infection_rate
    generator.global_recovery_rate = args.recovery_rate
    generator.global_mortality_rate = args.mortality_rate
    generator.strain_mutation_rate = args.strain_mutation_rate
    generator.cross_susceptibility_model = args.cross_susceptibility_model
    generator.cross_susceptibility_factor = args.cross_susceptibility_factor
    generator.initial_naive_population_size = args.initial_naive_population_size
    generator.simulation_run_time = args.run_time
    generator.naive_population_proportion_termination_threshold = args.termination_threshold
    generator.build_host_types()

    # The template is rendered once; only the output prefix differs between
    # replicates.
    template = generator.generate_xml_template()
    for rep in range(args.num_replicates):
        output_prefix = args.output_prefix + ".{:03}".format(rep + 1)
        xml = template.format(output_prefix=os.path.basename(output_prefix))
        xml_path = output_prefix + ".xml"
        # `with` closes the files even if a write fails.
        with open(xml_path, "w") as xml_file:
            xml_file.write(xml)
        if args.create_job_files:
            with open(output_prefix + ".sge", "w") as job_file:
                job_file.write(JOB_TEMPLATE.format(xml_path=os.path.basename(xml_path)))


if __name__ == "__main__":
    main()
# g = Epidemiology(4)
# print g.generate_xml()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment