-
-
Save admk/ecd4bf735c025c48bd45ecc27c908804 to your computer and use it in GitHub Desktop.
Reverse SSH tunneling to SLURM node using AutoSSH
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import os | |
import shutil | |
import socket | |
import getpass | |
import datetime | |
import argparse | |
import subprocess | |
import tempfile | |
# Absolute path of ./logs, resolved from the current working directory at
# import time, and created immediately (module-level side effect).
# Presumably this exists so that the relative `logs/%j.out` / `logs/%j.err`
# paths in SLURM_SCRIPT are valid when sbatch is run from this same
# directory — TODO confirm.
LOGDIR = os.path.join(os.path.abspath(os.curdir), 'logs')
os.makedirs(LOGDIR, exist_ok=True)
# Batch-script template rendered by submit_slurm_job() via
# SLURM_SCRIPT.format(autossh=..., args=..., port=...):
#   {args.partition}, {args.cpu_count}, {args.gres_type}, {args.gres_count},
#   {args.host}  - fields of the parsed command-line namespace;
#   {autossh}    - path of the autossh binary as found on the *submitting*
#                  machine (assumes the compute node sees the same path);
#   {port}       - port opened on {args.host} and forwarded back to port 22
#                  (sshd) of the compute node running the job.
# autossh runs in the foreground with -N (no remote command) and -M 0 (no
# autossh monitor port; liveness relies on the ServerAlive* options), so the
# job lasts as long as autossh keeps running, up to the 7-day --time limit.
# By default ssh binds a -R listener to the loopback interface of
# {args.host}, i.e. the tunnel is reached there as localhost:{port}.
# "\\" in this non-raw literal renders as one backslash (shell line
# continuation); ssh/autossh verbose output (-v) ends up in logs/%j.err.
SLURM_SCRIPT = """#!/bin/bash
#SBATCH -p {args.partition}
#SBATCH -o logs/%j.out
#SBATCH -e logs/%j.err
#SBATCH -N 1
#SBATCH --cpus-per-task={args.cpu_count}
#SBATCH --time=7-0:00
#SBATCH --gres={args.gres_type}:{args.gres_count}
cd $HOME
# autossh to keep the reverse ssh connection alive
{autossh} \\
-M 0 \\
-o 'ServerAliveInterval 30' -o 'ServerAliveCountMax 3' -N \\
-o ExitOnForwardFailure=yes -v \\
-R {port}:localhost:22 \\
{args.host}
"""
def check_connection(port):
    """Report whether something is accepting TCP connections on localhost:*port*.

    ``connect_ex`` returns 0 on a successful connect and an errno value
    otherwise (e.g. connection refused), so a free port simply yields False.

    Args:
        port: TCP port number to probe on the IPv4 loopback.

    Returns:
        True if the port is in use (a listener answered), else False.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        status = probe.connect_ex(('localhost', port))
    return status == 0
def first_unused_port(start=30001, stop=31000):
    """Return the first port in ``range(start, stop)`` with no local listener.

    Ports are probed in ascending order with check_connection(), i.e. by
    attempting a TCP connect to localhost. The answer is only a snapshot:
    another process may bind the port before the tunnel does.

    Args:
        start: first port to try (inclusive). Defaults to 30001.
        stop: end of the search range (exclusive). Defaults to 31000, so the
            default search covers ports 30001-30999 as before.

    Returns:
        The first port number for which check_connection() is False.

    Raises:
        RuntimeError: if every port in the range accepts connections
            (or the range is empty).
    """
    for port in range(start, stop):
        if not check_connection(port):
            return port
    raise RuntimeError("No unused port available.")
def submit_slurm_job(args):
    """Render the reverse-SSH batch script and submit it with ``sbatch``.

    The script is written to a temporary file, printed for the user, passed to
    ``sbatch`` and then deleted.

    Args:
        args: namespace as produced by parse_args(); reads ``port_number``,
            ``sbatch_args`` and every ``args.*`` field used in SLURM_SCRIPT
            (``partition``, ``cpu_count``, ``gres_type``, ``gres_count``,
            ``host``).

    Returns:
        The ``subprocess.CompletedProcess`` of the ``sbatch`` call. A non-zero
        exit status is not raised; sbatch prints its own errors.

    Raises:
        RuntimeError: if the requested port is already in use, if no free
            port is found, or if ``autossh`` is not on ``PATH``.
    """
    import shlex  # local import keeps the file's top-level imports unchanged

    print(f"date: {datetime.datetime.now():%Y-%m-%d %H:%M:%S}")
    if args.port_number:
        # check_connection() is True when something already listens on the
        # port, i.e. the port is taken -- that is the case we must refuse.
        if check_connection(args.port_number):
            raise RuntimeError(
                f"Port {args.port_number} is not available, "
                f"please choose another port.")
        port = args.port_number
    else:
        port = first_unused_port()

    autossh = shutil.which('autossh')
    if autossh is None:
        # Otherwise the literal text "None" would be written into the job
        # script and the job would only fail later on the compute node.
        raise RuntimeError("autossh executable not found on PATH.")
    # NOTE(review): the path is resolved on this (submitting) machine and
    # assumed to be valid on the compute node too -- confirm shared install.
    script = SLURM_SCRIPT.format(autossh=autossh, args=args, port=port)
    print(f"args: {args!r}")
    print(f"port: {port!r}")
    print(f"autossh: {autossh!r}")

    with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
        f.write(script)
    try:
        print(f"--- SLURM SCRIPT [{f.name}] ---")
        print(script)
        print("--- END SLURM SCRIPT ---")
        # shlex.split honours shell quoting, e.g. -a '--comment "two words"';
        # `or ''` guards against a None value from programmatic callers.
        extra = shlex.split(args.sbatch_args or '')
        cmd = ['sbatch'] + extra + [f.name]
        print(f"command: {' '.join(cmd)}")
        return subprocess.run(cmd)
    finally:
        # The script is deleted right after sbatch returns (as before); doing
        # it in `finally` also cleans up when sbatch cannot even be launched.
        os.remove(f.name)
def local_ip():
    """Return the IPv4 address this machine would use for outbound traffic.

    Calling ``connect`` on a UDP socket only selects a route and a source
    address; no datagram is sent to 8.8.8.8. A route to it must still exist,
    otherwise ``OSError`` propagates to the caller.

    Returns:
        The local address string reported by ``getsockname()``.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        probe.connect(('8.8.8.8', 80))
        address, _unused_port = probe.getsockname()
    finally:
        probe.close()
    return address
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    The default for ``--host`` is ``<current user>@<local outbound IPv4>``.
    If the local IP cannot be determined (local_ip() raises ``OSError``,
    e.g. when there is no network route), ``--host`` becomes required.
    This avoids crashing before the user's own ``--host`` is even read.

    Returns:
        argparse.Namespace with ``host``, ``partition``, ``cpu_count``,
        ``gres_type``, ``gres_count``, ``port_number`` and ``sbatch_args``.
    """
    parser = argparse.ArgumentParser(
        description='Submit a reverse SSH job to SLURM.')
    try:
        ip = local_ip()
    except OSError:
        ip = None
    if ip is None:
        default_host = None
        host_help = ('Host to connect to (required: the local IP address '
                     'could not be determined).')
    else:
        default_host = f'{getpass.getuser()}@{ip}'
        host_help = f'Host to connect to, default: {default_host!r}.'
    parser.add_argument(
        '-s', '--host', type=str, default=default_host,
        required=default_host is None, help=host_help)
    parser.add_argument(
        '-t', '--partition', type=str, default='gpu',
        help='Partition to submit to.')
    parser.add_argument(
        '-c', '--cpu-count', type=int, default=16, help='Number of CPU cores.')
    parser.add_argument(
        '-r', '--gres-type', type=str, default='gpu:a100-sxm4-80gb',
        help='Gres type; combined with --gres-count as --gres=TYPE:COUNT.')
    parser.add_argument(
        '-g', '--gres-count', type=int, default=1, help='Number of Gres.')
    parser.add_argument(
        '-p', '--port-number', type=int, default=None,
        help='Port number on host to expose, assigned if not specified.')
    parser.add_argument(
        '-a', '--sbatch-args', type=str, default='',
        help='Additional arguments to pass to sbatch.')
    return parser.parse_args()
def main(args=None):
    """Submit the reverse-SSH SLURM job.

    Args:
        args: an already-built argument namespace. When omitted (or falsy),
            the command line is parsed with parse_args() instead.
    """
    if not args:
        args = parse_args()
    submit_slurm_job(args)
# Script entry point: submit the job only when run directly, not on import.
# (Importing still creates LOGDIR, which happens at module level.)
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment