Skip to content

Instantly share code, notes, and snippets.

@admk
Last active December 28, 2023 00:25
Show Gist options
  • Save admk/ecd4bf735c025c48bd45ecc27c908804 to your computer and use it in GitHub Desktop.
Save admk/ecd4bf735c025c48bd45ecc27c908804 to your computer and use it in GitHub Desktop.
Reverse SSH tunneling to SLURM node using AutoSSH
#!/usr/bin/env python3
import argparse
import datetime
import getpass
import os
import shlex
import shutil
import socket
import subprocess
import tempfile
# Directory for SLURM job logs, under the current working directory.
# NOTE(review): LOGDIR itself is not referenced anywhere below. SLURM_SCRIPT
# hard-codes the matching relative paths "logs/%j.out" / "logs/%j.err", which
# SLURM presumably resolves against the job's working directory (by default
# the directory sbatch is run from), so keep the two in sync. Creating the
# directory here is an import-time side effect.
LOGDIR = os.path.join(os.path.abspath(os.curdir), 'logs')
os.makedirs(LOGDIR, exist_ok=True)
# Batch-script template rendered by submit_slurm_job() via
# SLURM_SCRIPT.format(autossh=..., args=..., port=...): {args.*} fields come
# from the parsed CLI namespace, {autossh} is the resolved autossh path and
# {port} is the port to expose on {args.host}. Any literal brace would need
# doubling. Each "\\" is one backslash in the string, i.e. a shell line
# continuation.
# The job forwards {port} on {args.host} back to port 22 on the compute node
# ("-R {port}:localhost:22").
# NOTE(review): --time=7-0:00 is hard-coded; confirm it fits the partition's
# time limit.
SLURM_SCRIPT = """#!/bin/bash
#SBATCH -p {args.partition}
#SBATCH -o logs/%j.out
#SBATCH -e logs/%j.err
#SBATCH -N 1
#SBATCH --cpus-per-task={args.cpu_count}
#SBATCH --time=7-0:00
#SBATCH --gres={args.gres_type}:{args.gres_count}
cd $HOME
# autossh to keep the reverse ssh connection alive
{autossh} \\
-M 0 \\
-o 'ServerAliveInterval 30' -o 'ServerAliveCountMax 3' -N \\
-o ExitOnForwardFailure=yes -v \\
-R {port}:localhost:22 \\
{args.host}
"""
def check_connection(port):
    """Tell whether a TCP connection to ``('localhost', port)`` succeeds.

    Returns:
        True if something on this machine accepts connections on *port*
        (the port is taken), False otherwise.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # connect_ex reports failure as an errno instead of raising.
        status = sock.connect_ex(('localhost', port))
    finally:
        sock.close()
    return status == 0
def first_unused_port(start=30001, stop=31000):
    """Return the lowest port in ``range(start, stop)`` that nothing listens on.

    A port counts as unused when ``check_connection`` cannot connect to it on
    localhost. The result is only a snapshot: another process may bind the
    port before the tunnel does.

    Args:
        start: First port to probe (inclusive). Defaults to 30001.
        stop: End of the probed range (exclusive). Defaults to 31000.

    Raises:
        RuntimeError: If every port in the range accepts connections (or the
            range is empty).
    """
    for port in range(start, stop):
        if not check_connection(port):
            return port
    raise RuntimeError("No unused port available.")
def submit_slurm_job(args):
    """Render SLURM_SCRIPT for *args* and submit it with ``sbatch``.

    The submitted job runs ``autossh -R <port>:localhost:22 <args.host>``,
    so ``<port>`` on ``args.host`` forwards back to port 22 on the compute
    node for as long as the job runs.

    Args:
        args: Namespace from ``parse_args``. This function reads
            ``port_number`` and ``sbatch_args``; SLURM_SCRIPT reads
            ``partition``, ``cpu_count``, ``gres_type``, ``gres_count`` and
            ``host``.

    Raises:
        RuntimeError: If the requested port already accepts connections on
            this machine, or ``autossh`` is not found on ``PATH``.
        ValueError: If ``args.sbatch_args`` contains unbalanced quotes.
        subprocess.CalledProcessError: If ``sbatch`` exits non-zero.
    """
    print(f"date: {datetime.datetime.now():%Y-%m-%d %H:%M:%S}")
    if args.port_number:
        # check_connection() is True when something already listens on the
        # port, i.e. the port is taken and the reverse forward would fail.
        # NOTE(review): this probes *this* machine, so it is only meaningful
        # when args.host is this machine (the parse_args default).
        if check_connection(args.port_number):
            raise RuntimeError(
                f"Port {args.port_number} is not available, "
                f"please choose another port.")
        port = args.port_number
    else:
        port = first_unused_port()
    autossh = shutil.which('autossh')
    if autossh is None:
        # Otherwise the rendered script would try to execute the word "None".
        raise RuntimeError("autossh not found on PATH.")
    # NOTE(review): assumes autossh is at the same path on the compute node
    # (e.g. a shared filesystem); confirm.
    script = SLURM_SCRIPT.format(autossh=autossh, args=args, port=port)
    print(f"args: {args!r}")
    print(f"port: {port!r}")
    print(f"autossh: {autossh!r}")
    # delete=False: the file must survive being closed so sbatch can read it
    # by name. The finally clause removes it even if sbatch is missing or
    # fails.
    with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
        f.write(script)
    try:
        print(f"--- SLURM SCRIPT [{f.name}] ---")
        print(script)
        print("--- END SLURM SCRIPT ---")
        # shlex.split honours shell quoting, e.g. --job-name='my job'.
        cmd = ['sbatch', *shlex.split(args.sbatch_args), f.name]
        print(f"command: {' '.join(cmd)}")
        # check=True so a rejected submission is not silently ignored.
        subprocess.run(cmd, check=True)
    finally:
        os.remove(f.name)
def local_ip():
    """Return this machine's outbound IPv4 address, falling back to its hostname.

    ``connect`` on a UDP socket sends no packets. It only makes the kernel
    choose the local address it would use to reach 8.8.8.8, and
    ``getsockname`` reads that address back.

    Returns:
        The local address as a dotted-quad string. If there is no route to
        8.8.8.8 (for example on an isolated cluster node), ``connect`` raises
        ``OSError`` and ``socket.gethostname()`` is returned instead.
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(('8.8.8.8', 80))
            return s.getsockname()[0]
    except OSError:
        # NOTE(review): assumes the compute node can resolve this hostname.
        return socket.gethostname()
def parse_args(argv=None):
    """Parse the command line for the reverse-SSH SLURM job.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, as the original zero-argument call did. Pass a
            list to parse programmatically.

    Returns:
        argparse.Namespace with ``host``, ``partition``, ``cpu_count``,
        ``gres_type``, ``gres_count``, ``port_number`` and ``sbatch_args``.
    """
    parser = argparse.ArgumentParser(
        description='Submit a reverse SSH job to SLURM.')
    # By default the compute node tunnels back to this machine: the current
    # user at this machine's outbound address.
    default_host = f'{getpass.getuser()}@{local_ip()}'
    parser.add_argument(
        '-s', '--host', type=str, default=default_host,
        help=f'Host to connect to, default: {default_host!r}.')
    parser.add_argument(
        '-t', '--partition', type=str, default='gpu',
        help='Partition to submit to.')
    parser.add_argument(
        '-c', '--cpu-count', type=int, default=16, help='Number of CPU cores.')
    parser.add_argument(
        '-r', '--gres-type', type=str, default='gpu:a100-sxm4-80gb',
        help='Generic resource type for --gres, default: %(default)r.')
    parser.add_argument(
        '-g', '--gres-count', type=int, default=1, help='Number of Gres.')
    parser.add_argument(
        '-p', '--port-number', type=int, default=None,
        help='Port number on host to expose, assigned if not specified.')
    parser.add_argument(
        '-a', '--sbatch-args', type=str, default='',
        help='Additional arguments to pass to sbatch. If the value starts '
             "with '-', use the --sbatch-args=VALUE form.")
    return parser.parse_args(argv)
def main(args=None):
    """Submit the reverse-SSH job, parsing the command line if *args* is falsy."""
    if not args:
        args = parse_args()
    submit_slurm_job(args)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment