Skip to content

Instantly share code, notes, and snippets.

@yaroslavvb
Last active April 24, 2019 02:29
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yaroslavvb/6cff9a26b74c4f47119d81309215b579 to your computer and use it in GitHub Desktop.
Save yaroslavvb/6cff9a26b74c4f47119d81309215b579 to your computer and use it in GitHub Desktop.
Example of using tmux to launch A3C training processes (parameter server, workers and TensorBoard)
import argparse
import os
import sys
from six.moves import shlex_quote
parser = argparse.ArgumentParser(description="Run commands")
parser.add_argument('-w', '--num-workers', default=1, type=int,
help="Number of workers")
parser.add_argument('-r', '--remotes', default=None,
help='The address of pre-existing VNC servers and '
'rewarders to use (e.g. -r vnc://localhost:5900+15900,vnc://localhost:5901+15901).')
parser.add_argument('-e', '--env-id', type=str, default="PongDeterministic-v3",
help="Environment id")
parser.add_argument('-l', '--log-dir', type=str, default="/tmp/pong",
help="Log directory path")
parser.add_argument('-n', '--dry-run', action='store_true',
help="Print out commands rather than executing them")
parser.add_argument('-m', '--mode', type=str, default='tmux',
help="tmux: run workers in a tmux session. nohup: run workers with nohup. child: run workers as child processes")
parser.add_argument('-rl', '--run-local', type=bool, default=False,
help="Whether or not run locally")
# Add visualise tag
parser.add_argument('--visualise', action='store_true',
help="Visualise the gym environment by running env.render() between each timestep")
def new_cmd(session, name, cmd, mode, logdir, shell):
    """Build the shell command that starts one process of the job.

    Args:
        session: tmux session name. In 'child'/'nohup' modes it only prefixes
            the log file name.
        name: window / process name (e.g. "ps", "w-0", "tb").
        cmd: either a ready-made shell string, or a list/tuple of arguments
            that are each converted with str(), shell-quoted and space-joined.
        mode: one of the following.
            'tmux' types the command into window ``session:name``.
            'child' runs it in the background.
            'nohup' runs it in the background under ``nohup <shell> -c``.
            In 'child' and 'nohup' modes, output goes to
            ``<logdir>/<session>.<name>.out``, and ``kill <pid>`` is appended
            to ``<logdir>/kill.sh``.
        logdir: directory for the .out files and kill.sh (unused for 'tmux').
        shell: shell binary used by 'nohup' mode (unused otherwise).

    Returns:
        A ``(name, shell_command)`` tuple.

    Raises:
        ValueError: if ``mode`` is not 'tmux', 'child' or 'nohup'.
    """
    if isinstance(cmd, (list, tuple)):
        cmd = " ".join(shlex_quote(str(v)) for v in cmd)
    if mode == 'tmux':
        return name, "tmux send-keys -t {}:{} {} Enter".format(session, name, shlex_quote(cmd))
    # Quote the paths so a logdir containing spaces or shell metacharacters
    # still works; shlex_quote leaves ordinary paths unchanged.
    out_file = shlex_quote("{}/{}.{}.out".format(logdir, session, name))
    kill_file = shlex_quote("{}/kill.sh".format(logdir))
    if mode == 'child':
        return name, "{} >{} 2>&1 & echo kill $! >>{}".format(cmd, out_file, kill_file)
    if mode == 'nohup':
        return name, "nohup {} -c {} >{} 2>&1 & echo kill $! >>{}".format(
            shell, shlex_quote(cmd), out_file, kill_file)
    # Previously fell through and returned None, which only failed later when
    # the caller tried to unpack the result.
    raise ValueError("unknown mode {!r}; expected 'tmux', 'child' or 'nohup'".format(mode))
# Name of the tmux session that run_in_window() and launch_training() use.
SESSION_NAME = "a3c"


def ossystem(cmd):
    """Echo a shell command to stdout, then run it with os.system.

    Args:
        cmd: command line passed verbatim to the shell.

    Returns:
        The exit status reported by os.system (0 on success), so callers can
        detect failures. Callers that ignore the result are unaffected.
    """
    print(cmd)
    # Flush before the child writes to the same stream. Otherwise, when stdout
    # is a pipe or file, the echoed command can appear after the child's output.
    sys.stdout.flush()
    return os.system(cmd)
# Names of the tmux windows already created by run_in_window(). An empty set
# means the tmux session has not been (re)created yet in this process.
initialized_windows = set()


def run_in_window(window, cmd_list):
    """Run a command in tmux window ``window`` of session SESSION_NAME.

    On the first call in this process, session SESSION_NAME is killed (if it
    exists) and recreated. Each window is created on first use, and
    OPENCV_OPENCL_RUNTIME and CUDA_VISIBLE_DEVICES are exported in it before
    its first command.

    Args:
        window: tmux window name.
        cmd_list: command pieces, joined with single spaces and typed into the
            window. The pieces are NOT shell-quoted individually, so one piece
            may itself be a complete command line.
    """
    # Always address the window through its session so we never type into a
    # same-named window of some other tmux session.
    target = "{}:{}".format(SESSION_NAME, window)

    def run(cmd):
        # Type the command into the window and press Enter. shlex_quote keeps
        # commands containing single quotes intact, unlike a literal '...'.
        ossystem("tmux send-keys -t {} {} Enter".format(target, shlex_quote(cmd)))

    if not initialized_windows:
        # First use: restart the session from scratch. -d starts it detached.
        # tmux can't start a windowless session, so start with a dummy
        # "blargh" window and rename it below.
        ossystem("tmux kill-session -t " + SESSION_NAME)
        ossystem("tmux new-session -s %s -n %s -d " % (SESSION_NAME, "blargh"))
    if window not in initialized_windows:
        if not initialized_windows:
            ossystem("tmux rename-window -t {}:blargh {}".format(SESSION_NAME, window))
        else:
            ossystem("tmux new-window -t {} -n {}".format(SESSION_NAME, window))
        # opencv causes segfault when used with TF GPU, turn off OpenCL to fix
        run("export OPENCV_OPENCL_RUNTIME=")
        # NOTE(review): GPU index 2 is hard-coded for every window; confirm.
        run("export CUDA_VISIBLE_DEVICES=2")
        initialized_windows.add(window)
    run(" ".join(cmd_list))
def _parse_remotes(remotes, num_workers):
    """Return the per-worker --remotes values parsed from the comma-separated flag.

    Args:
        remotes: comma-separated string with one entry per worker, or None.
        num_workers: number of worker processes.

    Returns:
        A list of ``num_workers`` strings. When ``remotes`` is None, every
        entry is "1".

    Raises:
        ValueError: if ``remotes`` does not have exactly ``num_workers`` entries.
    """
    if remotes is None:
        return ["1"] * num_workers
    parsed = remotes.split(',')
    if len(parsed) != num_workers:
        raise ValueError("--remotes has {} entries but --num-workers is {}".format(
            len(parsed), num_workers))
    return parsed


def launch_training(num_workers, remotes, env_id, logdir):
    """Launch the parameter server, workers and TensorBoard in tmux windows.

    Each process gets its own window in tmux session SESSION_NAME via
    run_in_window(). Afterwards window "w-0" is selected and the session is
    attached with ``tmux a``, which normally blocks until the user detaches.

    Args:
        num_workers: number of worker (non-ps) processes to launch.
        remotes: comma-separated list with one entry per worker (passed to
            each worker as --remotes), or None to pass "1" to every worker.
        env_id: gym env id, e.g. PongDeterministic-v3.
        logdir: TensorFlow logdir; created with ``mkdir -p``.

    Raises:
        ValueError: if ``remotes`` does not have exactly ``num_workers`` entries.
    """
    # Validate first so a bad --remotes value doesn't leave a half-started job.
    remotes = _parse_remotes(remotes, num_workers)
    ossystem('mkdir -p ' + shlex_quote(logdir))
    # arguments shared between worker and ps runs
    shared_args = ['--log-dir', logdir, '--env-id', env_id,
                   '--num-workers', str(num_workers)]
    # start parameter server
    run_in_window("ps", [sys.executable, 'worker.py'] + shared_args + ["--job-name", "ps"])
    # start workers; `python -u` keeps their output unbuffered
    worker_base_cmd = [sys.executable, '-u', 'worker.py'] + shared_args
    for i in range(num_workers):
        worker_cmd = worker_base_cmd + ["--job-name", "worker", "--task", str(i),
                                        "--remotes", remotes[i]]
        run_in_window("w-%d" % i, worker_cmd)
    # start TensorBoard
    run_in_window('tb', ["tensorboard --logdir {} --port 12344".format(logdir)])
    # attach to worker0
    ossystem('tmux select-window -t %s:%s' % (SESSION_NAME, 'w-0'))
    ossystem('tmux a -t ' + SESSION_NAME)


def create_commands(session, num_workers, remotes, env_id, logdir, shell='bash',
                    mode='tmux', visualise=False):
    """Build, without running, the shell commands that launch a training job.

    This is the command builder that run() calls. Its body previously sat
    orphaned inside launch_training, where its undefined names raised
    NameError.

    Args:
        session: tmux session name; also prefixes the per-process log files.
        num_workers: number of worker (non-ps) processes.
        remotes: comma-separated list, one entry per worker, or None.
        env_id: gym env id passed to worker.py.
        logdir: log directory passed to worker.py and used for cmd.sh,
            kill.sh and the .out files.
        shell: shell used by new_cmd's 'nohup' mode and for new tmux windows.
        mode: 'tmux', 'nohup' or 'child' (see new_cmd).
        visualise: when True, append --visualise to the worker.py arguments.

    Returns:
        ``(cmds, notes)``. ``cmds`` is a list of shell command strings to run
        in order. ``notes`` is a list of human-readable hints for the user.

    Raises:
        ValueError: if ``remotes`` does not have exactly ``num_workers`` entries.
    """
    base_cmd = [
        sys.executable, 'worker.py',
        '--log-dir', logdir, '--env-id', env_id,
        '--num-workers', str(num_workers)]
    if visualise:
        # NOTE(review): assumes worker.py accepts --visualise; confirm.
        base_cmd += ['--visualise']
    remotes = _parse_remotes(remotes, num_workers)

    cmds_map = [new_cmd(session, "ps", base_cmd + ["--job-name", "ps"], mode, logdir, shell)]
    for i in range(num_workers):
        cmds_map += [new_cmd(session, "w-%d" % i,
                             base_cmd + ["--job-name", "worker", "--task", str(i),
                                         "--remotes", remotes[i]],
                             mode, logdir, shell)]
    cmds_map += [new_cmd(session, "tb", ["tensorboard", "--logdir", logdir, "--port", "12345"],
                         mode, logdir, shell)]
    if mode == 'tmux':
        cmds_map += [new_cmd(session, "htop", ["htop"], mode, logdir, shell)]
    # Window names in creation order; only used to create tmux windows below.
    windows = [name for name, _ in cmds_map]

    notes = []
    cmds = [
        "mkdir -p {}".format(shlex_quote(logdir)),
        "echo {} {} > {}".format(
            sys.executable,
            ' '.join([shlex_quote(arg) for arg in sys.argv if arg != '-n']),
            shlex_quote("{}/cmd.sh".format(logdir))),
    ]
    if mode in ('nohup', 'child'):
        cmds += ["echo '#!/bin/sh' >{}".format(shlex_quote("{}/kill.sh".format(logdir)))]
        notes += ["Run `source {}/kill.sh` to kill the job".format(logdir)]
    if mode == 'tmux':
        notes += ["Use `tmux attach -t {}` to watch process output".format(session)]
        notes += ["Use `tmux kill-session -t {}` to kill the job".format(session)]
    else:
        notes += ["Use `tail -f {}/*.out` to watch process output".format(logdir)]
    notes += ["Point your browser to http://localhost:12345 to see Tensorboard"]

    if mode == 'tmux':
        cmds += [
            "kill $( lsof -i:12345 -t ) > /dev/null 2>&1",  # kill any process using tensorboard's port
            "kill $( lsof -i:12222-{} -t ) > /dev/null 2>&1".format(num_workers + 12222),  # kill any processes using ps / worker ports
            "tmux kill-session -t {}".format(session),
            "tmux new-session -s {} -n {} -d {}".format(session, windows[0], shell),
        ]
        for w in windows[1:]:
            cmds += ["tmux new-window -t {} -n {} {}".format(session, w, shell)]
        cmds += ["sleep 1"]

    cmds += [cmd for _, cmd in cmds_map]
    return cmds, notes
def run():
    """Alternative entry point built on create_commands() for session "a3c".

    Prints the generated commands and, unless -n/--dry-run was given, executes
    them all in a single os.system call. The usage notes are printed at the
    end in both cases. Unlike main(), this honours --mode, --dry-run and
    --visualise.
    """
    opts = parser.parse_args()
    commands, hints = create_commands(
        "a3c", opts.num_workers, opts.remotes, opts.env_id, opts.log_dir,
        mode=opts.mode, visualise=opts.visualise)
    execute = not opts.dry_run
    if execute:
        print("Executing the following commands:")
    else:
        print("Dry-run mode due to -n flag, otherwise the following commands would be executed:")
    print("\n".join(commands))
    print("")
    if execute:
        if opts.mode == "tmux":
            # NOTE(review): presumably blanked so tmux does not treat this as a
            # nested session when the script is started from inside tmux; confirm.
            os.environ["TMUX"] = ""
        # One shell invocation, one command per line.
        os.system("\n".join(commands))
    print("\n".join(hints))
def main():
    """Parse command-line flags and launch training through launch_training().

    Only --num-workers, --remotes, --env-id and --log-dir are consulted here.
    """
    flags = parser.parse_args()
    launch_training(
        num_workers=flags.num_workers,
        remotes=flags.remotes,
        env_id=flags.env_id,
        logdir=flags.log_dir,
    )


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment