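# Entry point for `jtm worker`.
#
# High-level flow:
#   1. Cache config values (SITE/JTM/RMQ/SLURM sections) into globals and locals.
#   2. Prepare the log/job directories and the logger, and optionally pin CPU
#      affinity and a per-worker memory limit (RLIMIT_AS).
#   3. For a static/dynamic worker not yet inside a SLURM allocation
#      (slurm_job_id == 0), write an sbatch script that relaunches `jtm worker`
#      num_workers_per_node times on the allocated node and submit it.
#   4. Otherwise, connect to RabbitMQ, start the helper processes, and consume
#      task requests from the pool's queue.
#
# This excerpt assumes the module-level imports (os, sys, signal, functools,
# resource, psutil, pika, multiprocessing as mp) and package helpers
# (make_dir, setup_custom_logger, logger, get_free_memory, run_sh_command,
# RmqConnectionHB, proc_clean, conn_clean, the *_proc callbacks) defined
# elsewhere in the JTM package.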
def worker(ctx: object, heartbeat_interval_param: int, custom_log_dir: str,
           custom_job_log_dir_name: str, pool_name_param: str, dry_run: bool,
           slurm_job_id_param: int, worker_type_param: str, cluster_name_param: str,
           worker_clone_time_rate_param: float, num_workers_per_node_param: int,
           worker_id_param: str, charging_account_param: str,
           num_nodes_to_request_param: int, num_cores_to_request_param: int,
           constraint_param: str, mem_per_node_to_request_param: str,
           mem_per_cpu_to_request_param: str,
           qos_param: str, job_time_to_request_param: str) -> int:
    global CONFIG
    CONFIG = ctx.obj['config']
    debug = ctx.obj['debug']
    # config file has precedence
    config_debug = CONFIG.configparser.getboolean("SITE", "debug")
    if config_debug:
        debug = config_debug
    global DEBUG
    DEBUG = debug
    global WORKER_TYPE
    WORKER_TYPE = CONFIG.constants.WORKER_TYPE
    global HB_MSG
    HB_MSG = CONFIG.constants.HB_MSG
    global VERSION
    VERSION = CONFIG.constants.VERSION
    global COMPUTE_RESOURCES
    COMPUTE_RESOURCES = CONFIG.constants.COMPUTE_RESOURCES
    global TASK_TYPE
    TASK_TYPE = CONFIG.constants.TASK_TYPE
    global DONE_FLAGS
    DONE_FLAGS = CONFIG.constants.DONE_FLAGS
    global NUM_WORKER_PROCS
    NUM_WORKER_PROCS = CONFIG.constants.NUM_WORKER_PROCS
    global TASK_KILL_TIMEOUT_MINUTE
    TASK_KILL_TIMEOUT_MINUTE = CONFIG.constants.TASK_KILL_TIMEOUT_MINUTE
    global CNAME
    CNAME = CONFIG.configparser.get("SITE", "instance_name")
    global JTM_HOST_NAME
    JTM_HOST_NAME = CONFIG.configparser.get("SITE", "jtm_host_name")
    global JTM_INNER_REQUEST_Q
    JTM_INNER_REQUEST_Q = CONFIG.configparser.get("JTM", "jtm_inner_request_q")
    global CTR
    CTR = CONFIG.configparser.getfloat("JTM", "clone_time_rate")
    global JTM_INNER_MAIN_EXCH
    JTM_INNER_MAIN_EXCH = CONFIG.configparser.get("JTM", "jtm_inner_main_exch")
    global JTM_CLIENT_HB_EXCH
    JTM_CLIENT_HB_EXCH = CONFIG.configparser.get("JTM", "jtm_client_hb_exch")
    global JTM_WORKER_HB_EXCH
    JTM_WORKER_HB_EXCH = CONFIG.configparser.get("JTM", "jtm_worker_hb_exch")
    global CLIENT_HB_Q_POSTFIX
    CLIENT_HB_Q_POSTFIX = CONFIG.configparser.get("JTM", "client_hb_q_postfix")
    global WORKER_HB_Q_POSTFIX
    WORKER_HB_Q_POSTFIX = CONFIG.configparser.get("JTM", "worker_hb_q_postfix")
    global JTM_TASK_KILL_EXCH
    JTM_TASK_KILL_EXCH = CONFIG.configparser.get("JTM", "jtm_task_kill_exch")
    global JTM_TASK_KILL_Q
    JTM_TASK_KILL_Q = CONFIG.configparser.get("JTM", "jtm_task_kill_q")
    global JTM_WORKER_POISON_EXCH
    JTM_WORKER_POISON_EXCH = CONFIG.configparser.get("JTM", "jtm_worker_poison_exch")
    global JTM_WORKER_POISON_Q
    JTM_WORKER_POISON_Q = CONFIG.configparser.get("JTM", "jtm_worker_poison_q")
    global NUM_PROCS_CHECK_INTERVAL
    NUM_PROCS_CHECK_INTERVAL = CONFIG.configparser.getfloat("JTM", "num_procs_check_interval")
    global ENV_ACTIVATION
    ENV_ACTIVATION = CONFIG.configparser.get("JTM", "env_activation")
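    # The remaining settings are read into locals rather than globals: the worker
    # config path, RabbitMQ host/port, user name, run mode, and the SLURM defaults
    # used when no command-line override is given.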
    WORKER_CONFIG_FILE = CONFIG.configparser.get("JTM", "worker_config_file")
    RMQ_HOST = CONFIG.configparser.get("RMQ", "host")
    RMQ_PORT = CONFIG.configparser.get("RMQ", "port")
    USER_NAME = CONFIG.configparser.get("SITE", "user_name")
    PRODUCTION = False
    if CONFIG.configparser.get("JTM", "run_mode") == "prod":
        PRODUCTION = True
    JOBTIME = CONFIG.configparser.get("SLURM", "jobtime")
    CONSTRAINT = CONFIG.configparser.get("SLURM", "constraint")
    CHARGE_ACCNT = CONFIG.configparser.get("SLURM", "charge_accnt")
    QOS = CONFIG.configparser.get("SLURM", "qos")
    PARTITION = CONFIG.configparser.get("SLURM", "partition")
    MEMPERCPU = CONFIG.configparser.get("SLURM", "mempercpu")
    MEMPERNODE = CONFIG.configparser.get("SLURM", "mempernode")
    NWORKERS = CONFIG.configparser.getint("JTM", "num_workers_per_node")
    NCPUS = CONFIG.configparser.getint("SLURM", "ncpus")
    global FILE_CHECK_INTERVAL
    FILE_CHECK_INTERVAL = CONFIG.configparser.getfloat("JTM", "file_check_interval")
    global FILE_CHECKING_MAX_TRIAL
    FILE_CHECKING_MAX_TRIAL = CONFIG.configparser.getint("JTM", "file_checking_max_trial")
    global FILE_CHECK_INT_INC
    FILE_CHECK_INT_INC = CONFIG.configparser.getfloat("JTM", "file_check_int_inc")
    # Job dir setting
    job_script_dir_name = os.path.join(CONFIG.configparser.get("JTM", "log_dir"), "job")
    if custom_job_log_dir_name:
        job_script_dir_name = custom_job_log_dir_name
    make_dir(job_script_dir_name)
    # Log dir setting
    log_dir_name = os.path.join(CONFIG.configparser.get("JTM", "log_dir"), "log")
    if custom_log_dir:
        log_dir_name = custom_log_dir
    make_dir(log_dir_name)
    print("JTM Worker, version: {}".format(VERSION))
    # Set a unique worker id if one is provided in the params
    if worker_id_param:
        global UNIQ_WORKER_ID
        UNIQ_WORKER_ID = worker_id_param
    # Logger setting
    log_level = "info"
    if DEBUG:
        log_level = "debug"
    setup_custom_logger(log_level, log_dir_name,
                        1, 1,
                        worker_id=UNIQ_WORKER_ID)
    logger.info("\n*****************\nDebug mode is %s\n*****************"
                % ("ON" if DEBUG else "OFF"))
    # # Todo: site specific setting --> remove
    # CORI_KNL_CHARGE_ACCNT = CONFIG.configparser.get("SLURM", "knl_charge_accnt")
    # CORI_KNL_QOS = CONFIG.configparser.get("SLURM", "knl_qos")
    hearbeat_interval = CONFIG.configparser.getfloat("JTM", "worker_hb_send_interval")
    logger.info("Set jtm log file location to %s", log_dir_name)
    logger.info("Set jtm job file location to %s", job_script_dir_name)
    logger.info("RabbitMQ broker: %s", RMQ_HOST)
    logger.info("RabbitMQ port: %s", RMQ_PORT)
    logger.info("Pika version: %s", pika.__version__)
    logger.info("JTM user name: %s", USER_NAME)
    logger.info("Unique worker ID: %s", UNIQ_WORKER_ID)
    logger.info("\n*****************\nRun mode is %s\n*****************"
                % ("PROD" if PRODUCTION else "DEV"))
    logger.info("env activation: %s", ENV_ACTIVATION)
    logger.info("JTM config file: %s" % (CONFIG.config_file))
    # Slurm config
    num_nodes_to_request = 0
    if num_nodes_to_request_param:
        num_nodes_to_request = num_nodes_to_request_param
    # Todo
    # Cori and JGI Cloud are exclusive allocation. So this is not needed.
    # assert mem_per_node_to_request_param is not None, "-N needs --mem-per-cpu (-mc) setting."
    # 11.13.2018 decided to remove all default values from argparse
    num_workers_per_node = num_workers_per_node_param if num_workers_per_node_param else NWORKERS
    assert num_workers_per_node > 0
    mem_per_cpu_to_request = mem_per_cpu_to_request_param if mem_per_cpu_to_request_param else MEMPERCPU
    mem_per_node_to_request = mem_per_node_to_request_param if mem_per_node_to_request_param else MEMPERNODE
    assert mem_per_cpu_to_request
    assert mem_per_node_to_request
    num_cpus_to_request = num_cores_to_request_param if num_cores_to_request_param else NCPUS
    assert num_cpus_to_request
    # Set CPU affinity for limiting the number of cores to use
    if worker_type_param != "manual" and worker_id_param and worker_id_param.find('_') != -1:
        # ex)
        # total_cpu_num = 32, num_workers_per_node_param = 4
        # split_cpu_num = 8
        # worker_number - 1 == 0 --> [0, 1, 2, 3, 4, 5, 6, 7]
        # worker_number - 1 == 1 --> [8, 9, 10, 11, 12, 13, 14, 15]
        proc = psutil.Process(PARENT_PROCESS_ID)
        try:
            # Use the appended worker id number as worker_number
            # ex) 5wZwyCM8rxgNtERsU8znJU_1 --> extract "1" --> worker number
            worker_number = int(worker_id_param.split('_')[-1]) - 1
        except ValueError:
            logger.exception("Not an expected worker ID. Cancelling CPU affinity setting")
        else:
            # Note: may need to use num_cpus_to_request outside LBL
            total_cpu_num = psutil.cpu_count()
            logger.info("Total number of cores available: {}".format(total_cpu_num))
            split_cpu_num = int(total_cpu_num / num_workers_per_node)
            cpu_affinity_list = list(range(worker_number * split_cpu_num,
                                           (worker_number + 1) * split_cpu_num))
            logger.info("Set CPU affinity to use: {}".format(cpu_affinity_list))
            try:
                proc.cpu_affinity(cpu_affinity_list)
            except Exception as e:
                logger.exception("Failed to set the CPU usage limit: %s" % (e))
                sys.exit(1)
    # Set memory upper limit
    # Todo: May need to use all free_memory on Cori and Lbl
    system_free_mem_bytes = get_free_memory()
    logger.info("Total available memory (MBytes): %d"
                % (system_free_mem_bytes / 1024.0 / 1024.0))
    if worker_type_param != "manual" and num_workers_per_node > 1:
        try:
            mem_per_node_to_request_byte = int(mem_per_node_to_request.lower()
                                               .replace("gb", "")
                                               .replace("g", "")) * 1024.0 * 1024.0 * 1024.0
            logger.info("Requested memory for this worker (MBytes): %d"
                        % (mem_per_node_to_request_byte / 1024.0 / 1024.0))
            # if the requested mempernode is larger than the system's available mem space
            if system_free_mem_bytes < mem_per_node_to_request_byte:
                logger.critical("Requested memory space is not available")
                logger.critical("Available space: %d (MBytes)"
                                % (system_free_mem_bytes / 1024.0 / 1024.0))
                logger.critical("Requested space: %d (MBytes)"
                                % (mem_per_node_to_request_byte / 1024.0 / 1024.0))
                # Option 1
                # mem_per_node_to_request_byte = system_free_mem_bytes
                # Option 2
                raise MemoryError
            MEM_LIMIT_PER_WORKER_BYTES = int(mem_per_node_to_request_byte /
                                             num_workers_per_node)
        except Exception as e:
            logger.exception("Failed to compute the memory limit: %s", mem_per_node_to_request)
            logger.exception(e)
            sys.exit(1)
        try:
            soft, hard = resource.getrlimit(resource.RLIMIT_AS)
            resource.setrlimit(resource.RLIMIT_AS, (MEM_LIMIT_PER_WORKER_BYTES, hard))
            logger.info("Set the memory usage upper limit (MBytes): %d"
                        % (MEM_LIMIT_PER_WORKER_BYTES / 1024.0 / 1024.0))
        except Exception as e:
            logger.exception("Failed to set the memory usage limit: %s", mem_per_node_to_request)
            logger.exception(e)
            sys.exit(1)
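    # Worked example for the limit set above (hypothetical values): --mem=115G with
    # 4 workers per node -> RLIMIT_AS = int(115 * 1024**3 / 4) bytes, i.e. about
    # 29,440 MBytes per worker process.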
    job_time_to_request = job_time_to_request_param if job_time_to_request_param else JOBTIME
    constraint = constraint_param if constraint_param else CONSTRAINT
    charging_account = charging_account_param if charging_account_param else CHARGE_ACCNT
    qos = qos_param if qos_param else QOS
    global THIS_WORKER_TYPE
    THIS_WORKER_TYPE = worker_type_param
    job_name = "jtm_worker_" + pool_name_param
    # Set task queue name
    inner_task_request_queue = None
    if heartbeat_interval_param:
        hearbeat_interval = heartbeat_interval_param
    # Start hb receive thread
    tp_name = ""
    if pool_name_param:
        tp_name = pool_name_param
    assert pool_name_param is not None, "User pool name is not set"
    inner_task_request_queue = JTM_INNER_REQUEST_Q + "." + pool_name_param
    worker_clone_time_rate = worker_clone_time_rate_param if worker_clone_time_rate_param else CTR
    if THIS_WORKER_TYPE in ("static", "dynamic"):
        assert cluster_name_param != "" and \
            cluster_name_param != "local", "A static or dynamic worker needs a cluster setting (-cl)."
    slurm_job_id = slurm_job_id_param
    cluster_name = cluster_name_param
    if cluster_name == "cori" and mem_per_cpu_to_request != "" and \
            float(mem_per_cpu_to_request.replace("GB", "").replace("G", "").replace("gb", "")) > 1.0:
        logger.critical("--mem-per-cpu on Cori shouldn't be larger than 1GB. Use '--mem' instead.")
        sys.exit(1)
    logger.info("RabbitMQ broker: %s", RMQ_HOST)
    logger.info("Task queue name: %s", inner_task_request_queue)
    logger.info("Worker type: %s", THIS_WORKER_TYPE)
    if slurm_job_id == 0 and THIS_WORKER_TYPE in ["static", "dynamic"]:
        batch_job_script_file = os.path.join(job_script_dir_name, "jtm_%s_worker_%s.job" %
                                             (THIS_WORKER_TYPE, UNIQ_WORKER_ID))
        batch_job_script_str = ""
        batch_job_misc_params = ""
        worker_config = CONFIG.config_file if CONFIG else ""
        if WORKER_CONFIG_FILE:
            worker_config = WORKER_CONFIG_FILE
        if cluster_name in ("cori", "lawrencium", "jgi_cloud", "jaws_lbl_gov", "lbl", "jgi_cluster"):
            with open(batch_job_script_file, "w") as jf:
                batch_job_script_str += "#!/bin/bash -l"
                if cluster_name in ("cori",):
                    if num_nodes_to_request_param:
                        batch_job_script_str += """
#SBATCH -N %(num_nodes_to_request)d
#SBATCH --mem=%(mem)s""" % dict(num_nodes_to_request=num_nodes_to_request, mem=mem_per_node_to_request)
                        batch_job_misc_params += " -N %(num_nodes_to_request)d -m %(mem)s" % \
                            dict(num_nodes_to_request=num_nodes_to_request,
                                 mem=mem_per_node_to_request)
                        if num_cores_to_request_param:
                            batch_job_script_str += """
#SBATCH -c %(num_cores)d""" % dict(num_cores=num_cpus_to_request)
                            batch_job_misc_params += " -c %(num_cores)d" % \
                                dict(num_cores=num_cpus_to_request)
                    else:
                        batch_job_script_str += """
#SBATCH -c %(num_cores)d""" % dict(num_cores=num_cpus_to_request)
                        batch_job_misc_params += " -c %(num_cores)d" % \
                            dict(num_cores=num_cpus_to_request)
                        if mem_per_node_to_request:
                            batch_job_script_str += """
#SBATCH --mem=%(mem)s""" % dict(mem=mem_per_node_to_request)
                            batch_job_misc_params += " -m %(mem)s " % \
                                dict(mem=mem_per_node_to_request)
                        else:
                            batch_job_script_str += """
#SBATCH --mem-per-cpu=%(mempercore)s""" % dict(mempercore=mem_per_cpu_to_request)
                            batch_job_misc_params += " -mc %(mempercore)s" % \
                                dict(mempercore=mem_per_cpu_to_request)
                    if worker_id_param:
                        batch_job_misc_params += " -wi %(worker_id)s_${i}" % \
                            dict(worker_id=UNIQ_WORKER_ID)
                    ###########################
                    if 1:
                        # Need to set both --qos=genepool (or genepool_shared) _and_ -A fungalp
                        # OR
                        # no qos _and_ -A m342 _and_ -C haswell
                        # Note: currently constraint in ["haswell" | "knl"]
                        if constraint == "haswell":
                            if qos_param:
                                batch_job_script_str += """
#SBATCH -q %(qosname)s""" % dict(qosname=qos)
                                batch_job_misc_params += " -q %(qosname)s" % dict(qosname=qos)
                            else:
                                batch_job_script_str += """
#SBATCH -q %(qosname)s""" % dict(qosname=qos)
                            batch_job_script_str += """
#SBATCH -C haswell"""
                            if charging_account == "m342":
                                batch_job_misc_params += " -A %(sa)s" % dict(sa="m342")
                            batch_job_script_str += """
#SBATCH -A %(charging_account)s""" % dict(charging_account=charging_account)
                        elif constraint == "knl":
                            # Note: Basic KNL setting = "-q regular -A m342 -C knl"
                            #
                            # Note: KNL MCDRAM setting -> cache or flat
                            #  cache mode - MCDRAM is configured entirely as a last-level cache (L3)
                            #  flat mode - MCDRAM is configured entirely as addressable memory
                            #  ex) #SBATCH -C knl,quad,cache
                            #  ex) #SBATCH -C knl,quad,flat
                            #  --> srun <srun options> numactl -p 1 yourapplication.x
                            #
                            # Note: for knl, we should use m342
                            #
                            # Note: for knl, charging_account can be set via runtime (like lanl, m3408)
                            #
                            batch_job_script_str += """
#SBATCH -C knl
#SBATCH -A %(charging_account)s
#SBATCH -q %(qosname)s""" % \
                                dict(charging_account=charging_account, qosname=qos)
                            batch_job_misc_params += " -A %(charging_account)s -q %(qosname)s" % \
                                dict(charging_account=charging_account, qosname=qos)
                        elif constraint == "skylake":
                            # Example usage with skylake for Brian F.
                            # 120G
                            # ======================
                            # -t 48:00:00 -c 16 --job-name=mga-627530 --mem=115G --qos=genepool_special
                            # --exclusive -A gtrqc
                            #
                            # 250G
                            # ======================
                            # -t 96:00:00 -c 72 --job-name=mga-627834 --mem=240G -C skylake --qos=jgi_exvivo
                            # -A gtrqc
                            #
                            # 500G
                            # ======================
                            # -t 96:00:00 -c 72 --job-name=mga-627834 --mem=240G -C skylake --qos=jgi_exvivo
                            # -A gtrqc
                            batch_job_script_str += """
#SBATCH -C skylake
#SBATCH -A %(charging_account)s
#SBATCH -q %(qosname)s""" % \
                                dict(charging_account=charging_account, qosname=qos)
                            batch_job_misc_params += " -A %(charging_account)s -q %(qosname)s" % \
                                dict(charging_account=charging_account, qosname=qos)
                    excl_param = ""
                    if constraint != "skylake":
                        excl_param = "#SBATCH --exclusive"
                    tq_param = ""
                    if pool_name_param:
                        tq_param = "-p " + pool_name_param
                    batch_job_script_str += """
#SBATCH -t %(wall_time)s
#SBATCH --job-name=%(job_name)s
#SBATCH -o %(job_dir)s/jtm_%(worker_type)s_worker_%(worker_id)s.out
#SBATCH -e %(job_dir)s/jtm_%(worker_type)s_worker_%(worker_id)s.err
%(exclusive)s
module unload python
%(env_activation_cmd)s
%(export_jtm_config_file)s
for i in {1..%(num_workers_per_node)d}
do
echo "jobid: $SLURM_JOB_ID"
jtm %(set_jtm_config_file)s %(debug)s worker --slurm_job_id $SLURM_JOB_ID \
-cl cori \
-wt %(worker_type)s \
-t %(wall_time)s \
--clone_time_rate %(clone_time_rate)f %(task_queue)s \
--num_worker_per_node %(num_workers_per_node)d \
-C %(constraint)s \
-m %(mem)s \
%(other_params)s &
sleep 1
done
wait
""" % \
                        dict(debug="--debug" if DEBUG else "",
                             wall_time=job_time_to_request,
                             job_dir=job_script_dir_name,
                             worker_id=UNIQ_WORKER_ID,
                             worker_type=THIS_WORKER_TYPE,
                             clone_time_rate=worker_clone_time_rate,
                             task_queue=tq_param,
                             num_workers_per_node=num_workers_per_node,
                             env_activation_cmd=ENV_ACTIVATION,
                             other_params=batch_job_misc_params,
                             constraint=constraint,
                             mem=mem_per_node_to_request,
                             job_name=job_name,
                             exclusive=excl_param,
                             export_jtm_config_file="export JTM_CONFIG_FILE=%s"
                                                    % worker_config,
                             set_jtm_config_file="--config=%s"
                                                 % worker_config)
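                    # At this point batch_job_script_str is a complete Cori sbatch
                    # script: the #SBATCH directives composed above plus a loop that
                    # backgrounds num_workers_per_node `jtm worker` processes and
                    # waits for them.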
elif cluster_name in ("lawrencium", "jgi_cloud", "jaws_lbl_gov", "jgi_cluster", "lbl"):
if worker_id_param:
batch_job_misc_params += " -wi %(worker_id)s_${i}" \
% dict(worker_id=UNIQ_WORKER_ID)
tp_param = ""
if pool_name_param:
tp_param = "-p " + pool_name_param
part_param = ""
if cluster_name == "lawrencium":
part_param = PARTITION
else:
part_param = PARTITION
qos_param = ""
if cluster_name == "lawrencium":
qos_param = QOS
else:
qos_param = QOS
charge_param = ""
if cluster_name == "lawrencium":
charge_param = CHARGE_ACCNT
else:
charge_param = CHARGE_ACCNT
nnode_param = 1
if num_nodes_to_request_param:
nnode_param = num_nodes_to_request
mnode_param = "#SBATCH --mem=%(mem)s" \
% dict(mem=mem_per_node_to_request)
batch_job_script_str += """
#SBATCH --time=%(wall_time)s
#SBATCH --job-name=%(job_name)s
#SBATCH --partition=%(partition_name)s
#SBATCH --qos=%(qosname)s
#SBATCH --account=%(charging_account)s
#SBATCH --nodes=%(num_nodes_to_request)d
%(mem_per_node_setting)s
#SBATCH -o %(job_dir)s/jtm_%(worker_type)s_worker_%(worker_id)s.out
#SBATCH -e %(job_dir)s/jtm_%(worker_type)s_worker_%(worker_id)s.err
%(env_activation_cmd)s
%(export_jtm_config_file)s
for i in {1..%(num_workers_per_node)d}
do
echo "jobid: $SLURM_JOB_ID"
jtm %(set_jtm_config_file)s %(debug)s worker --slurm_job_id $SLURM_JOB_ID \
-cl %(lbl_cluster_name)s \
-wt %(worker_type)s \
-t %(wall_time)s \
--clone_time_rate %(clone_time_rate)f %(task_queue)s \
--num_worker_per_node %(num_workers_per_node)d \
-m %(mem)s \
%(other_params)s &
sleep 1
done
wait
""" % \
dict(debug="--debug" if DEBUG else "",
wall_time=job_time_to_request,
job_name=job_name,
partition_name=part_param,
qosname=qos_param,
charging_account=charge_param,
num_nodes_to_request=nnode_param,
mem_per_node_setting=mnode_param,
worker_id=UNIQ_WORKER_ID,
job_dir=job_script_dir_name,
env_activation_cmd=ENV_ACTIVATION,
num_workers_per_node=num_workers_per_node,
mem=mem_per_node_to_request,
lbl_cluster_name=cluster_name,
worker_type=THIS_WORKER_TYPE,
clone_time_rate=worker_clone_time_rate,
task_queue=tp_param,
other_params=batch_job_misc_params,
export_jtm_config_file="export JTM_CONFIG_FILE=%s"
% worker_config,
set_jtm_config_file="--config=%s"
% worker_config)
jf.writelines(batch_job_script_str)
os.chmod(batch_job_script_file, 0o775)
if dry_run:
print(batch_job_script_str)
sys.exit(0)
sbatch_cmd = "sbatch --parsable %s" % (batch_job_script_file)
_, _, ec = run_sh_command(sbatch_cmd, log=logger)
assert ec == 0, "Failed to run 'jtm worker' to sbatch dynamic worker."
return ec
elif cluster_name == "aws":
pass
# If it's spawned by sbatch
# Todo: need to record job_id, worker_id, worker_type, starting_time, wallclocktime
# scontrol show jobid -dd <jobid> ==> EndTime
# scontrol show jobid <jobid> ==> EndTime
# sstat --format=AveCPU,AvePages,AveRSS,AveVMSize,JobID -j <jobid> --allsteps
#
# if endtime - starttime <= 10%, execute sbatch again
# if slurm_job_id != 0 and THIS_WORKER_TYPE == "static":
# logger.debug("worker_type: {}".format(THIS_WORKER_TYPE))
# logger.debug("slurm_job_id: {}".format(slurm_job_id))
# Dynamic workers creates [[two]] children when it approaches to the wallclocktime limit
# considering the task queue length
# Also, maintain the already requested number of workers
# if no more workers needed, it won't call sbatch
# elif slurm_job_id != 0 and THIS_WORKER_TYPE == "dynamic":
# logger.debug("worker_type: {}".format(THIS_WORKER_TYPE))
# logger.debug("slurm_job_id: {}".format(slurm_job_id))
    # Remote broker (rmq.nersc.gov)
    rmq_conn = RmqConnectionHB(config=CONFIG)
    conn = rmq_conn.open()
    ch = conn.channel()
    # ch.confirm_delivery()
    ch.exchange_declare(exchange=JTM_INNER_MAIN_EXCH,
                        exchange_type="direct",
                        passive=False,
                        durable=True,
                        auto_delete=False)
    # Declare the task receiving queue (client --> worker)
    #
    # If a queue is durable, RabbitMQ will never lose the queue.
    # If a queue is exclusive, it is deleted when the channel that declared it is closed.
    # If a queue is auto-deleted, it is deleted when there are no subscriptions left on it.
    #
    ch.queue_declare(queue=inner_task_request_queue,
                     durable=True,
                     exclusive=False,
                     auto_delete=True)
    ch.queue_bind(exchange=JTM_INNER_MAIN_EXCH,
                  queue=inner_task_request_queue,
                  routing_key=inner_task_request_queue)
    logger.info("Waiting for a request...")
    logger.debug("Main pid = {}".format(PARENT_PROCESS_ID))
    pid_list = []
    # Start task termination proc
    try:
        task_kill_proc_hdl = mp.Process(target=recv_task_kill_request_proc)
        task_kill_proc_hdl.start()
        pid_list.append(task_kill_proc_hdl)
    except Exception as e:
        logger.exception("recv_task_kill_request_proc: {}".format(e))
        proc_clean(pid_list)
        conn_clean(conn, ch)
        sys.exit(1)
    # Start send_hb_to_client_proc proc (worker --> client heartbeat)
    try:
        send_hb_to_client_proc_hdl = mp.Process(target=send_hb_to_client_proc,
                                                args=(hearbeat_interval,
                                                      slurm_job_id,
                                                      mem_per_node_to_request,
                                                      mem_per_cpu_to_request,
                                                      num_cpus_to_request,
                                                      job_time_to_request,
                                                      worker_clone_time_rate,
                                                      inner_task_request_queue,
                                                      tp_name,
                                                      num_workers_per_node,
                                                      JTM_WORKER_HB_EXCH,
                                                      WORKER_HB_Q_POSTFIX))
        send_hb_to_client_proc_hdl.start()
        pid_list.append(send_hb_to_client_proc_hdl)
    except Exception as e:
        logger.exception("send_hb_to_client_proc: {}".format(e))
        proc_clean(pid_list)
        conn_clean(conn, ch)
        sys.exit(1)
    logger.info("Start sending my heartbeat to the client every %d sec to %s"
                % (hearbeat_interval, WORKER_HB_Q_POSTFIX))
    # Start poison receive thread
    try:
        recv_poison_proc_hdl = mp.Process(target=recv_reproduce_or_die_proc,
                                          args=(pool_name_param,
                                                cluster_name,
                                                mem_per_node_to_request,
                                                mem_per_cpu_to_request,
                                                num_nodes_to_request,
                                                num_cpus_to_request,
                                                job_time_to_request,
                                                worker_clone_time_rate,
                                                num_workers_per_node,
                                                JTM_WORKER_POISON_EXCH,
                                                JTM_WORKER_POISON_Q))
        recv_poison_proc_hdl.start()
        pid_list.append(recv_poison_proc_hdl)
    except Exception as e:
        logger.exception("recv_reproduce_or_die_proc: {}".format(e))
        proc_clean(pid_list)
        conn_clean(conn, ch)
        sys.exit(1)
    # Start hb receive thread (client --> worker heartbeat)
    try:
        recv_hb_from_client_proc_hdl = mp.Process(target=recv_hb_from_client_proc2,
                                                  args=(inner_task_request_queue,
                                                        JTM_CLIENT_HB_EXCH,
                                                        CLIENT_HB_Q_POSTFIX))
        recv_hb_from_client_proc_hdl.start()
        pid_list.append(recv_hb_from_client_proc_hdl)
    except Exception as e:
        logger.exception("Worker termination request: {}".format(e))
        proc_clean(pid_list)
        conn_clean(conn, ch)
        sys.exit(1)
    # Check the total number of child processes
    try:
        check_processes_hdl = mp.Process(target=check_processes,
                                         args=(pid_list,))
        check_processes_hdl.start()
        pid_list.append(check_processes_hdl)
    except Exception as e:
        logger.exception("check_processes: {}".format(e))
        proc_clean(pid_list)
        conn_clean(conn, ch)
        sys.exit(1)

    def signal_handler(signum, frame):
        proc_clean(pid_list)

    signal.signal(signal.SIGTERM, signal_handler)
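    # prefetch_count=1: take at most one unacknowledged task at a time, so at most
    # one task thread runs per worker process (see the pika threaded-consumer
    # example referenced below).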
    # Waiting for request
    ch.basic_qos(prefetch_count=1)
    # OLD
    # try:
    #     ch.basic_consume(queue=inner_task_request_queue,
    #                      on_message_callback=do_work,
    #                      auto_ack=False)
    # except OSError as err:
    #     logger.exception("Worker terminated: {}".format(err))
    #     proc_clean()
    #     conn_clean()
    #     sys.exit(1)
    # NEW
    # Ref) https://github.com/pika/pika/blob/1.0.1/examples/
    #      https://stackoverflow.com/questions/51752890/how-to-disable-heartbeats-with-pika-and-rabbitmq
    #      https://github.com/pika/pika/blob/master/examples/basic_consumer_threaded.py
    threads = []
    on_message_callback = functools.partial(on_task_request,
                                            args=(conn, threads))
    ch.basic_consume(queue=inner_task_request_queue,
                     on_message_callback=on_message_callback)
    try:
        ch.start_consuming()
    except KeyboardInterrupt:
        proc_clean(pid_list)
        conn_clean(conn, ch)
    # Wait for all to complete
    # Note: prefetch_count=1 ==> #thread = 1
    for thread in threads:
        thread.join()
    if conn:
        conn.close()
    return 0
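# Example of the relaunch command the generated sbatch script executes on a compute
# node (values are hypothetical; the flags are the ones composed above):
#   jtm --config=/path/to/jtm.ini --debug worker --slurm_job_id $SLURM_JOB_ID \
#       -cl cori -wt dynamic -t 48:00:00 --clone_time_rate 0.2 -p my_pool \
#       --num_worker_per_node 4 -C haswell -m 115G -wi <worker_id>_1 &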