Skip to content

Instantly share code, notes, and snippets.

@shunghsiyu
Forked from somic/udp_hole_punch_tester.py
Last active June 7, 2021 07:37
Show Gist options
  • Save shunghsiyu/eafd18134e747f211560be316d5ff985 to your computer and use it in GitHub Desktop.
Save shunghsiyu/eafd18134e747f211560be316d5ff985 to your computer and use it in GitHub Desktop.
Scripts to setup mosh connection with server behind firewall

mosh-nat

Scripts to set up a mosh connection with a server behind a firewall

Setup

  • Install the stuntman package to get the stunclient binary both locally and on the server
  • Place the udp_hole_punch script, marked executable, in a directory that is in $PATH, both locally and on the server, so that it can be called by name
  • Compile mosh-nat-bind.c into mnb.so and place it inside $HOME/bin/ locally (it is not needed on the server)

Usage

mosh-nat $server

Where $server is what you normally give to the ssh command (e.g. user@host, or better yet, referencing a server inside your SSH config).

#!/usr/bin/env python3
import asyncio
import collections
import os
import random
import re
import shlex
import subprocess
import sys
# STUN server queried by stunclient, locally and (over ssh) on the server.
# The port is kept as a string because it is passed straight through as a
# subprocess argument.
STUN_SERVER = 'stun.l.google.com'
STUN_PORT = '19302'
# Patterns for the "Local address" / "Mapped address" lines in stunclient
# output; both only match dotted IPv4 addresses followed by ":<port>".
LOCAL_ADDRESS_RE = re.compile(r'Local address: (?P<ip>[\d.]+):(?P<port>\d+)')
MAPPED_ADDRESS_RE = re.compile(r'Mapped address: (?P<ip>[\d.]+):(?P<port>\d+)')
# Pattern for the "MOSH CONNECT <port> <key>" line in mosh-server output.
MOSH_KEY_RE = re.compile(r'MOSH CONNECT \d+ (?P<key>.+)')
# One host's endpoint as reported by stunclient: its local ip/port and the
# mapped ip/port. All four fields are the strings captured by the regexes.
Address = collections.namedtuple(
'Address',
['ip', 'port', 'mapped_ip', 'mapped_port'],
)
def parse_stun(stdout):
    """Build an ``Address`` from the text printed by ``stunclient``.

    Args:
        stdout: Decoded stunclient output containing a "Local address" and a
            "Mapped address" line.

    Returns:
        An ``Address`` whose four fields are the captured strings.

    Raises:
        RuntimeError: If either address line is missing; the offending output
            is attached as the second exception argument.
    """
    local = LOCAL_ADDRESS_RE.search(stdout)
    if local is None:
        raise RuntimeError('Cannot find local address and port,', stdout)

    mapped = MAPPED_ADDRESS_RE.search(stdout)
    if mapped is None:
        raise RuntimeError('Cannot find mapped address and port,', stdout)

    return Address(
        ip=local['ip'],
        port=local['port'],
        mapped_ip=mapped['ip'],
        mapped_port=mapped['port'],
    )
async def find_local_mapping():
    """Run ``stunclient`` on this machine and return the parsed ``Address``.

    Raises:
        RuntimeError: If stunclient exits non-zero (its stderr is attached) or
            its output cannot be parsed.
    """
    argv = ('stunclient', STUN_SERVER, STUN_PORT)
    stun = await asyncio.create_subprocess_exec(
        *argv,
        stdin=subprocess.DEVNULL,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    out, err = await stun.communicate()
    if stun.returncode == 0:
        return parse_stun(out.decode())
    raise RuntimeError('Failed to run stunclient locally,', err)
async def find_remote_mapping(target):
    """Run ``stunclient`` on ``target`` through ssh and return its ``Address``.

    ssh is started with ``-v`` and its stderr is inherited (not captured), so
    its verbose output goes wherever this process's stderr goes.

    Args:
        target: Destination handed to ssh as-is.

    Raises:
        RuntimeError: If the ssh command exits non-zero or the output cannot
            be parsed.
    """
    stun_argv = ['stunclient', STUN_SERVER, STUN_PORT]
    ssh = await asyncio.create_subprocess_exec(
        'ssh', '-v', target, shlex.join(stun_argv),
        stdin=subprocess.DEVNULL,
        stdout=subprocess.PIPE,
        stderr=None,
    )
    # stderr was not piped, so the second element is always None.
    out, _unused = await ssh.communicate()
    if ssh.returncode != 0:
        raise RuntimeError('Failed to run stunclient remotely at', target)
    return parse_stun(out.decode())
async def punch_local(local_port, remote_mapped_ip, remote_mapped_port):
    """Run ``udp_hole_punch`` on this machine towards the remote mapping.

    Args:
        local_port: Local UDP port to send from (string).
        remote_mapped_ip: Mapped IP of the remote side.
        remote_mapped_port: Mapped port of the remote side (string).

    Raises:
        RuntimeError: If udp_hole_punch exits non-zero; its stderr is attached.
    """
    argv = ('udp_hole_punch', local_port, remote_mapped_ip, remote_mapped_port)
    puncher_proc = await asyncio.create_subprocess_exec(
        *argv,
        stdin=subprocess.DEVNULL,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    _out, err = await puncher_proc.communicate()
    if puncher_proc.returncode != 0:
        raise RuntimeError('Failed to run udp_hole_punch locally,', err)
async def punch_remote(target, remote_port, local_mapped_ip, local_mapped_port):
    """Run ``udp_hole_punch`` on ``target`` (via ssh) towards the local mapping.

    Args:
        target: Destination handed to ssh as-is.
        remote_port: UDP port the remote side should send from (string).
        local_mapped_ip: Mapped IP of this machine.
        local_mapped_port: Mapped port of this machine (string).

    Raises:
        RuntimeError: If the ssh command exits non-zero; its stderr is attached.
    """
    punch_argv = ['udp_hole_punch', remote_port, local_mapped_ip, local_mapped_port]
    ssh = await asyncio.create_subprocess_exec(
        'ssh', target, shlex.join(punch_argv),
        stdin=subprocess.DEVNULL,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    _out, err = await ssh.communicate()
    if ssh.returncode != 0:
        raise RuntimeError('Failed to run udp_hole_punch remotely,', err)
async def start_server(target, port):
    """Start ``mosh-server new -p <port>`` on ``target`` and return its key.

    Args:
        target: Destination handed to ssh as-is.
        port: UDP port passed to mosh-server's ``-p`` option (string).

    Returns:
        The key captured from the "MOSH CONNECT <port> <key>" output line.

    Raises:
        RuntimeError: If the ssh command exits non-zero, or no MOSH CONNECT
            line is found in its stdout.
    """
    server_cmd = shlex.join(['mosh-server', 'new', '-p', port])
    ssh = await asyncio.create_subprocess_exec(
        'ssh', target, server_cmd,
        stdin=subprocess.DEVNULL,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    out, err = await ssh.communicate()
    if ssh.returncode != 0:
        raise RuntimeError('Failed to start mosh-server remotely,', err)

    found = MOSH_KEY_RE.search(out.decode())
    if found is None:
        raise RuntimeError('Failed to extract mosh-server key', out)
    return found.group('key')
async def setup(target):
    """Discover both mappings, punch the UDP hole, then start mosh-server.

    Args:
        target: Destination handed to ssh as-is.

    Returns:
        A tuple ``(local_port, key, server_ip, server_port)``: the local port
        found by STUN, the mosh-server key, and the server's mapped ip/port.
    """
    # Query STUN from both ends concurrently.
    mappings = await asyncio.gather(
        find_local_mapping(),
        find_remote_mapping(target),
    )
    local, remote = mappings

    # Both punchers run concurrently, each aimed at the other side's mapping.
    local_punch = punch_local(local.port, remote.mapped_ip, remote.mapped_port)
    remote_punch = punch_remote(
        target, remote.port, local.mapped_ip, local.mapped_port,
    )
    await asyncio.gather(local_punch, remote_punch)
    print('UDP hole punched', local, remote)

    # mosh-server is started on the same remote port the hole was punched from.
    key = await start_server(target, remote.port)
    return local.port, key, remote.mapped_ip, remote.mapped_port
def main(target):
    """Punch a UDP hole to ``target`` and replace this process with mosh-client.

    The environment handed to mosh-client carries MNB_PORT (local port to send
    from), LD_PRELOAD (``$HOME/bin/mnb.so``) and MOSH_KEY. This function does
    not return on success because ``os.execvp`` replaces the process.

    Args:
        target: Destination handed to ssh as-is.
    """
    local_port, key, server_ip, server_port = asyncio.run(setup(target))

    env_overrides = [
        ('MNB_PORT', local_port),
        ('LD_PRELOAD', os.path.join(os.environ['HOME'], 'bin', 'mnb.so')),
        ('MOSH_KEY', key),
    ]
    for name, value in env_overrides:
        os.putenv(name, value)

    print(local_port, key, server_ip, server_port)
    client_argv = ['mosh-client', server_ip, server_port]
    os.execvp(client_argv[0], client_argv)
if __name__ == '__main__':
    # Exit with a usage hint rather than an IndexError traceback when the
    # server argument is missing. Extra arguments are still ignored, as before.
    if len(sys.argv) < 2:
        sys.exit('Usage: mosh-nat $server')
    main(sys.argv[1])
/*
Based on: https://raw.githubusercontent.com/yongboy/bindp/
Original Copyright (C) 2014 nieyong
email: nieyong@staff.weibo.com
web: http://www.blogjava.net/yongboy
License: LGPL-2.1
*/
/*
LD_PRELOAD library to override bind() and sendto(),
forcing bind() to use specific options depending on env vars:
- MNB_IPV4=1.2.3.4 - specified IPv4 address.
- MNB_PORT=34730 - specified port.
- MNB_REUSE_ADDR=1 - SO_REUSEADDR option - socket(7).
- MNB_REUSE_PORT=1 - SO_REUSEPORT option - socket(7).
- MNB_IP_TRANSPARENT=1 - IP_TRANSPARENT option - ip(7).
Limitations (hacks):
- Only binds IPv4 (AF_INET) sockets.
- Tracks last fd used in sendto(fd, ...) and does bind() for it once.
Here to force mosh-client to connect from specified local port.
Compile on Linux (>=3.9) with:
gcc -nostartfiles -fpic -shared \
-ldl -D_GNU_SOURCE mosh-nat-bind.c -o mnb.so
Usage example relevant to mosh-client:
MNB_PORT=34731 LD_PRELOAD=./mnb.so \
MOSH_KEY=KBwxklPHaRqV7l5OgE3OsA mosh-client 10.0.1.13 34732
(connecting from 10.0.1.1:34731 to 10.0.1.13:34732)
*/
#include <stdio.h>
#include <stdlib.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <dlfcn.h>
#include <errno.h>
/* Pointers to the next bind()/sendto() in symbol lookup order; filled in by
   _init() through dlsym(RTLD_NEXT, ...). */
int (*real_bind)(int, const struct sockaddr *, socklen_t);
int (*real_sendto)( int fd, const void *message,
size_t length, int flags, const struct sockaddr *sk, socklen_t dest_len );
/* Result of inet_addr(MNB_IPV4); 0 means "do not override the address". */
unsigned long int bind_addr_saddr = 0;
/* One-element array holding the address that sendto() passes to bind(). */
struct sockaddr_in local_sockaddr_in[] = { 0 };
/* Port parsed from MNB_PORT (htons() is applied where it is used);
   0 means "do not override the port". */
unsigned int bind_port_saddr = 0;
/* Values parsed from MNB_REUSE_PORT / MNB_REUSE_ADDR / MNB_IP_TRANSPARENT;
   when non-zero they are passed as the option value to setsockopt(). */
unsigned int reuse_port = 0;
unsigned int reuse_addr = 0;
unsigned int ip_transparent = 0;
/* Last fd seen by sendto(); used so the forced bind() is attempted only when
   the fd changes. */
int bind_last_fd = -1;
/*
 * Library initialiser: resolves the real bind()/sendto() and reads the MNB_*
 * environment variables into the module-level settings.
 *
 * - MNB_IPV4: parsed with inet_addr() into bind_addr_saddr and the template.
 * - MNB_PORT: accepted only in 1..65535; other non-zero values are reported
 *   and ignored instead of being silently truncated by htons().
 * - MNB_REUSE_ADDR / MNB_REUSE_PORT / MNB_IP_TRANSPARENT: parsed with atoi().
 */
void _init(void){
    const char *err;

    /* Look up the implementations that this library shadows. */
    real_bind = dlsym(RTLD_NEXT, "bind");
    if ((err = dlerror()) != NULL) fprintf(stderr, "dlsym (bind): %s\n", err);
    real_sendto = dlsym(RTLD_NEXT, "sendto");
    if ((err = dlerror()) != NULL) fprintf(stderr, "dlsym (sendto): %s\n", err);

    /* The template handed to bind() from sendto() must always be a valid
       AF_INET address, including when only MNB_PORT is set; previously
       sin_family stayed 0 (AF_UNSPEC) in that case. */
    local_sockaddr_in->sin_family = AF_INET;

    char *bind_addr_env;
    if ((bind_addr_env = getenv("MNB_IPV4"))) {
        bind_addr_saddr = inet_addr(bind_addr_env);
        local_sockaddr_in->sin_addr.s_addr = bind_addr_saddr;
        local_sockaddr_in->sin_port = htons(0);
    }

    char *bind_port_env;
    if ((bind_port_env = getenv("MNB_PORT"))) {
        int port = atoi(bind_port_env);
        if (port > 0 && port <= 65535) {
            bind_port_saddr = (unsigned int)port;
            local_sockaddr_in->sin_port = htons(bind_port_saddr);
        } else if (port != 0) {
            /* 0 (or unparsable text) keeps meaning "no port override". */
            fprintf(stderr, "mnb: ignoring out-of-range MNB_PORT: %s\n",
                bind_port_env);
        }
    }

    char *reuse_addr_env;
    if ((reuse_addr_env = getenv("MNB_REUSE_ADDR")))
        reuse_addr = atoi(reuse_addr_env);

    char *reuse_port_env;
    if ((reuse_port_env = getenv("MNB_REUSE_PORT")))
        reuse_port = atoi(reuse_port_env);

    char *ip_transparent_env;
    if ((ip_transparent_env = getenv("MNB_IP_TRANSPARENT")))
        ip_transparent = atoi(ip_transparent_env);
}
/*
 * bind() override: applies the configured socket options, then forwards to
 * the real bind() with the address and/or port replaced by the MNB_IPV4 /
 * MNB_PORT values.
 *
 * Only AF_INET addresses that are at least sizeof(struct sockaddr_in) long
 * are rewritten; anything else (AF_INET6, AF_UNIX, NULL, short buffers) is
 * passed through untouched, because writing sin_addr/sin_port into those
 * would corrupt them. The rewrite is done on a private copy so the caller's
 * const sockaddr is never modified.
 *
 * Returns whatever the real bind() returns.
 */
int bind(int fd, const struct sockaddr *sk, socklen_t sl) {
    struct sockaddr_in patched;

    /* setsockopt() results are ignored on purpose: the options are
       best-effort and the real bind() is attempted regardless. */
    if (reuse_addr)
        setsockopt( fd, SOL_SOCKET,
            SO_REUSEADDR, &reuse_addr, sizeof(reuse_addr) );
    if (reuse_port)
        setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &reuse_port, sizeof(reuse_port));
    if (ip_transparent)
        setsockopt( fd, SOL_IP,
            IP_TRANSPARENT, &ip_transparent, sizeof(ip_transparent) );

    if ( (bind_addr_saddr || bind_port_saddr)
            && sk != NULL
            && sl >= sizeof(patched)
            && sk->sa_family == AF_INET ) {
        patched = *(const struct sockaddr_in *)sk;
        if (bind_addr_saddr) patched.sin_addr.s_addr = bind_addr_saddr;
        if (bind_port_saddr) patched.sin_port = htons(bind_port_saddr);
        return real_bind(fd, (const struct sockaddr *)&patched, sizeof(patched));
    }

    return real_bind(fd, sk, sl);
}
/*
 * sendto() override: the first time a given fd is seen (i.e. whenever fd
 * differs from the previous call's fd), and the destination is AF_INET and an
 * address or port override is configured, bind() the socket to the template
 * address before sending. The result of that bind() is ignored; the datagram
 * is always forwarded to the real sendto().
 *
 * A NULL destination (allowed for connected sockets) is forwarded without
 * attempting the bind, since there is no address family to inspect.
 *
 * Returns whatever the real sendto() returns.
 */
ssize_t sendto(
        int fd, const void *message, size_t length,
        int flags, const struct sockaddr *sk, socklen_t dest_len ) {
    if (bind_last_fd != fd) {
        bind_last_fd = fd;
        if ( sk != NULL
                && sk->sa_family == AF_INET
                && (bind_addr_saddr || bind_port_saddr) )
            bind(fd, (struct sockaddr *)local_sockaddr_in,
                sizeof(local_sockaddr_in[0]));
    }
    return real_sendto(fd, message, length, flags, sk, dest_len);
}
#!/usr/bin/env python3
#
# udp_hole_punch - UDP Hole Punching test tool
#
# Usage: udp_hole_punch local_port remote_host remote_port
#
# Run this script simultaneously on 2 hosts to test if they can punch
# a UDP hole to each other.
#
# * if local_port < 1024, must be root.
#
# Copyright (C) 2009 Dmitriy Samovskiy, http://somic.org
# Copyright (C) 2020 Shung-Hsi Yu, https://shungh.si
#
# License: Apache License, Version 2.0
# http://www.apache.org/licenses/
#
import sys, os, time, socket, random
from select import select
def log(*args):
    """Print a timestamp followed by the space-joined ``str()`` of each arg."""
    message = ' '.join(map(str, args))
    print(time.asctime(), message)
def puncher(local_port, remote_host, remote_port, attempts=60, interval=0.5):
    """Try to punch a UDP hole between ``local_port`` and a remote peer.

    Each round sends ``"<my_token> <remote_token>"`` to the peer, where
    ``remote_token`` is ``"_"`` until a datagram from the peer has been
    received. Once it is known, ``" ok"`` is appended so that the peer can
    tell (from the three fields) that its token arrived here.

    Args:
        local_port: UDP port to bind on all interfaces.
        remote_host: Host or IP to send to.
        remote_port: UDP port to send to.
        attempts: Maximum number of rounds (default 60, as before).
        interval: Seconds to sleep between rounds (default 0.5, as before).

    Returns:
        True if a token was received from the remote side, else False.
    """
    my_token = str(random.random())
    remote_token = "_"
    remote_knows_our_token = False

    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        sock.bind(('', local_port))
        log("my_token =", my_token)
        # settimeout() alone decides the blocking mode; the setblocking(0)
        # call that used to precede it was overridden and had no effect.
        sock.settimeout(5)
        for i in range(attempts):
            r, w, x = select([sock], [sock], [], 0)
            if remote_token != "_" and remote_knows_our_token:
                log("we are done - hole was punched from both ends")
                break
            if r:
                data, addr = sock.recvfrom(1024)
                log("recv:", data)
                # Decode before parsing: the tokens are compared with and
                # formatted into str, and bytes would leak into the protocol
                # as "b'...'" reprs.
                fields = data.decode('ascii', errors='replace').split()
                if fields:  # ignore empty / whitespace-only datagrams
                    if remote_token == "_":
                        remote_token = fields[0]
                        log("remote_token is now", remote_token)
                    if len(fields) == 3:
                        log("remote end signals it knows our token")
                        remote_knows_our_token = True
            if w:
                data = "%s %s" % (my_token, remote_token)
                if remote_token != "_":
                    data += " ok"
                log("sending:", data)
                sock.sendto(data.encode('ascii'), (remote_host, remote_port))
                log("sent", i)
            time.sleep(interval)
        log("done")
    finally:
        # Release the port even if bind/recv/send raised.
        sock.close()
    return remote_token != "_"
if __name__ == '__main__':
    # Exit with a usage hint rather than an IndexError traceback.
    if len(sys.argv) < 4:
        sys.exit("Usage: udp_hole_punch local_port remote_host remote_port")
    local_port = int(sys.argv[1])
    remote_host = sys.argv[2]
    remote_port = int(sys.argv[3])
    if puncher(local_port, remote_host, remote_port):
        log("Punched UDP hole to %s:%d successfully" % (remote_host, remote_port))
    else:
        log("Failed to punch hole")
        # Report failure through the exit status so that callers checking the
        # return code can detect it; previously the script always exited 0.
        sys.exit(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment