Last active
April 24, 2019 02:29
-
-
Save yaroslavvb/6cff9a26b74c4f47119d81309215b579 to your computer and use it in GitHub Desktop.
Example of using tmux
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import os | |
import sys | |
from six.moves import shlex_quote | |
def _str2bool(value):
    """Converts a command-line string such as "true"/"False"/"1"/"no" to a bool.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value -- including "False" -- would come out as True.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # The empty string stays False, as it was under bool("").
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


# Command-line interface of the launcher script.
parser = argparse.ArgumentParser(description="Run commands")
parser.add_argument('-w', '--num-workers', default=1, type=int,
                    help="Number of workers")
parser.add_argument('-r', '--remotes', default=None,
                    help='The address of pre-existing VNC servers and '
                         'rewarders to use (e.g. -r vnc://localhost:5900+15900,vnc://localhost:5901+15901).')
parser.add_argument('-e', '--env-id', type=str, default="PongDeterministic-v3",
                    help="Environment id")
parser.add_argument('-l', '--log-dir', type=str, default="/tmp/pong",
                    help="Log directory path")
parser.add_argument('-n', '--dry-run', action='store_true',
                    help="Print out commands rather than executing them")
parser.add_argument('-m', '--mode', type=str, default='tmux',
                    help="tmux: run workers in a tmux session. nohup: run workers with nohup. child: run workers as child processes")
# Still takes an explicit value (``-rl true`` / ``-rl false``) as before, but
# the value is now actually interpreted instead of always yielding True.
parser.add_argument('-rl', '--run-local', type=_str2bool, default=False,
                    help="Whether or not run locally")
# Add visualise tag
parser.add_argument('--visualise', action='store_true',
                    help="Visualise the gym environment by running env.render() between each timestep")
def new_cmd(session, name, cmd, mode, logdir, shell):
    """Builds the shell command line that starts ``cmd`` under ``mode``.

    Args:
      session: tmux session name; also used as prefix of the ``.out`` log file
      name: window name (tmux mode) and middle part of the ``.out`` log file
      cmd: command to run, either a ready-made string or a list/tuple of
        arguments (each argument is stringified and shell-quoted, then joined
        with spaces)
      mode: 'tmux', 'child' or 'nohup'
      logdir: directory receiving ``<session>.<name>.out`` and ``kill.sh``
      shell: shell executable used to run the command in 'nohup' mode

    Returns:
      ``(name, command_line)``; None when ``mode`` is none of the three above.
    """
    if isinstance(cmd, (list, tuple)):
        cmd = " ".join(map(shlex_quote, map(str, cmd)))

    if mode == 'tmux':
        # Type the command into window <session>:<name> and press Enter.
        keys = "tmux send-keys -t {}:{} {} Enter".format(
            session, name, shlex_quote(cmd))
        return name, keys

    if mode in ('child', 'nohup'):
        # Both background modes log to <logdir>/<session>.<name>.out and
        # append a "kill <pid>" line for the new process to <logdir>/kill.sh.
        suffix = ">{}/{}.{}.out 2>&1 & echo kill $! >>{}/kill.sh".format(
            logdir, session, name, logdir)
        if mode == 'child':
            launcher = cmd
        else:
            launcher = "nohup {} -c {}".format(shell, shlex_quote(cmd))
        return name, "{} {}".format(launcher, suffix)

    return None
# Name of the tmux session that this script kills, re-creates and attaches to;
# every window opened by run_in_window lives in this session.
SESSION_NAME = "a3c"
def ossystem(cmd):
    """Echoes a shell command to stdout, then runs it with ``os.system``.

    Args:
      cmd: shell command line to execute

    Returns:
      The status value reported by ``os.system`` (0 on success). It used to be
      discarded; returning it lets callers detect a failed command, while
      callers that ignore the result behave exactly as before.
    """
    print(cmd)
    return os.system(cmd)
# Names of the tmux windows that run_in_window has already created.
initialized_windows = set()


def run_in_window(window, cmd_list):
    """Runs command in tmux window, initializing tmux session and window if
    necessary.

    The first call kills any existing SESSION_NAME session and starts a fresh
    detached one. The first time a given window name is seen, the window is
    created and two environment variables are exported in it before the
    command is typed.

    Args:
      window: name of the tmux window to run the command in
      cmd_list: list of strings; they are joined with single spaces (no
        per-item quoting) and typed into the window followed by Enter
    """
    # Qualify the window with the session, the same "session:window" target
    # form new_cmd uses, so a same-named window in another session is never hit.
    target = "{}:{}".format(SESSION_NAME, window)

    def run(cmd):
        # shlex_quote instead of a bare '...' wrapper: a command that itself
        # contains a single quote would otherwise break the shell line.
        ossystem("tmux send-keys -t {} {} Enter".format(target, shlex_quote(cmd)))

    # if nothing initialized, restart session
    if not initialized_windows:
        ossystem('tmux kill-session -t ' + SESSION_NAME)

        # -d starts new session in detached mode
        # since can't start windowless tmux, start with dummy window and rename
        # later
        ossystem('tmux new-session -s %s -n %s -d ' % (SESSION_NAME, "blargh"))

    if window not in initialized_windows:
        if not initialized_windows:
            # very first window: reuse the dummy one created with the session
            ossystem('tmux rename-window -t %s:blargh %s' % (SESSION_NAME, window))
        else:
            ossystem("tmux new-window -t {} -n {}".format(SESSION_NAME, window))

        # opencv causes segfault when used with TF GPU, turn off OpenCL to fix
        run("export OPENCV_OPENCL_RUNTIME=")
        # NOTE(review): GPU index 2 is hard-coded; presumably specific to the
        # author's machine -- confirm before running elsewhere.
        run("export CUDA_VISIBLE_DEVICES=2")
        initialized_windows.add(window)

    run(' '.join(cmd_list))
def _split_remotes(remotes, num_workers):
    """Returns one ``--remotes`` value per worker from the comma-separated flag."""
    if remotes is None:
        return ["1"] * num_workers
    specs = remotes.split(',')
    if len(specs) != num_workers:
        raise ValueError("expected {} remotes (one per worker), got {}: {!r}".format(
            num_workers, len(specs), remotes))
    return specs


def launch_training(num_workers, remotes, env_id, logdir):
    """Launches PS and worker processes in tmux windows and attaches to them.

    Opens one window for the parameter server ("ps"), one per worker
    ("w-<i>") and one for TensorBoard ("tb", port 12344), then selects the
    "w-0" window and attaches to the tmux session.

    Args:
      num_workers: number of worker (non-ps processes) to launch
      remotes: comma-separated string with one entry per worker, passed to
        worker i as ``--remotes``; None gives every worker the value "1"
      env_id: gym env id, ie PongDeterministic-v3
      logdir: TensorFlow logdir

    Raises:
      ValueError: if ``remotes`` does not contain exactly ``num_workers``
        entries.
    """
    # Validate first so a bad --remotes value does not leave a half-started
    # tmux session behind.
    remotes = _split_remotes(remotes, num_workers)

    ossystem('mkdir -p '+logdir)

    # arguments shared between worker and ps runs
    shared_args = [
        '--log-dir', logdir, '--env-id', env_id,
        '--num-workers', str(num_workers)]

    # start parameter server
    run_in_window("ps", [sys.executable, 'worker.py'] + shared_args
                  + ["--job-name", "ps"])

    # start workers; -u runs python with unbuffered stdout/stderr
    worker_base = [sys.executable, '-u', 'worker.py'] + shared_args
    for i, remote in enumerate(remotes):
        run_in_window("w-" + str(i),
                      worker_base + ["--job-name", "worker", "--task", str(i),
                                     "--remotes", remote])

    # start TensorBoard
    run_in_window('tb', ["tensorboard --logdir {} --port 12344".format(logdir)])

    # attach to worker0
    ossystem('tmux select-window -t %s:%s' % (SESSION_NAME, 'w-0'))
    ossystem('tmux a -t ' + SESSION_NAME)


def create_commands(session, num_workers, remotes, env_id, logdir, shell='bash', mode='tmux', visualise=False):
    """Builds, without running them, the shell commands that launch a job.

    NOTE(review): the ``def`` line and the ``windows`` assignment of this
    function were missing from the file; its statements sat inside
    launch_training, where they raised NameError on ``session``. The
    signature is reconstructed from the call in run() and from the names the
    body uses. The ``shell='bash'`` default is an assumption -- TODO confirm.

    Args:
      session: tmux session name, also the prefix of the ``.out`` log files
      num_workers: number of worker processes
      remotes: comma-separated string, one entry per worker, or None
      env_id: gym env id
      logdir: log directory
      shell: shell program started in each tmux window / used by nohup mode
      mode: 'tmux', 'nohup' or 'child' (see new_cmd)
      visualise: accepted because run() passes it. NOTE(review): none of the
        statements found in the file use it -- TODO confirm intended effect.

    Returns:
      ``(cmds, notes)``: the shell command strings to execute in order, and
      human-readable hints to print for the user.

    Raises:
      ValueError: if ``remotes`` does not contain exactly ``num_workers``
        entries.
    """
    remotes = _split_remotes(remotes, num_workers)

    base_cmd = [
        sys.executable, 'worker.py',
        '--log-dir', logdir, '--env-id', env_id,
        '--num-workers', str(num_workers)]

    cmds_map = [new_cmd(session, "ps", base_cmd + ["--job-name", "ps"], mode, logdir, shell)]
    for i in range(num_workers):
        cmds_map += [new_cmd(session,
            "w-%d" % i, base_cmd + ["--job-name", "worker", "--task", str(i), "--remotes", remotes[i]], mode, logdir, shell)]

    cmds_map += [new_cmd(session, "tb", ["tensorboard", "--logdir", logdir, "--port", "12345"], mode, logdir, shell)]
    if mode == 'tmux':
        cmds_map += [new_cmd(session, "htop", ["htop"], mode, logdir, shell)]

    # new_cmd returns (name, command) pairs; the names double as window names.
    windows = [name for name, _ in cmds_map]

    notes = []
    cmds = [
        "mkdir -p {}".format(logdir),
        # record how this launcher was invoked (minus the dry-run flag)
        "echo {} {} > {}/cmd.sh".format(sys.executable, ' '.join([shlex_quote(arg) for arg in sys.argv if arg != '-n']), logdir),
    ]
    if mode == 'nohup' or mode == 'child':
        cmds += ["echo '#!/bin/sh' >{}/kill.sh".format(logdir)]
        notes += ["Run `source {}/kill.sh` to kill the job".format(logdir)]
    if mode == 'tmux':
        notes += ["Use `tmux attach -t {}` to watch process output".format(session)]
        notes += ["Use `tmux kill-session -t {}` to kill the job".format(session)]
    else:
        notes += ["Use `tail -f {}/*.out` to watch process output".format(logdir)]
    notes += ["Point your browser to http://localhost:12345 to see Tensorboard"]

    if mode == 'tmux':
        cmds += [
            "kill $( lsof -i:12345 -t ) > /dev/null 2>&1",  # kill any process using tensorboard's port
            "kill $( lsof -i:12222-{} -t ) > /dev/null 2>&1".format(num_workers+12222),  # kill any processes using ps / worker ports
            "tmux kill-session -t {}".format(session),
            "tmux new-session -s {} -n {} -d {}".format(session, windows[0], shell)
        ]
        for w in windows[1:]:
            cmds += ["tmux new-window -t {} -n {} {}".format(session, w, shell)]
        cmds += ["sleep 1"]
    for window, cmd in cmds_map:
        cmds += [cmd]

    return cmds, notes
def run():
    """Prints, and unless ``--dry-run`` is given executes, the launch commands.

    The commands and the closing notes come from ``create_commands`` for the
    session "a3c". When executing, all commands are handed to ``os.system``
    as one newline-separated script; in tmux mode the ``TMUX`` environment
    variable is blanked first.

    NOTE(review): nothing in the visible code calls run(); the ``__main__``
    guard calls main() instead.
    """
    args = parser.parse_args()
    cmds, notes = create_commands(
        "a3c", args.num_workers, args.remotes, args.env_id, args.log_dir,
        mode=args.mode, visualise=args.visualise)
    execute = not args.dry_run

    if execute:
        banner = "Executing the following commands:"
    else:
        banner = "Dry-run mode due to -n flag, otherwise the following commands would be executed:"
    print(banner)
    print("\n".join(cmds))
    print("")

    if execute:
        if args.mode == "tmux":
            os.environ["TMUX"] = ""
        script = "\n".join(cmds)
        os.system(script)

    print('\n'.join(notes))
def main():
    """Parses the command line and launches training via launch_training.

    Only --num-workers, --remotes, --env-id and --log-dir are consulted here;
    the other flags are parsed but not passed on.
    """
    options = parser.parse_args()
    launch_training(
        num_workers=options.num_workers,
        remotes=options.remotes,
        env_id=options.env_id,
        logdir=options.log_dir,
    )


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment