Skip to content

Instantly share code, notes, and snippets.

@klustic
Created August 4, 2021 17:52
Show Gist options
  • Save klustic/c3b44a89e41768b5e9daecec32aea98f to your computer and use it in GitHub Desktop.
Save klustic/c3b44a89e41768b5e9daecec32aea98f to your computer and use it in GitHub Desktop.
nunya
#!/usr/bin/env python
import logging
import random
import select
import shlex
import signal
import socket
import ssl
import struct
import sys
MTYPE_NOOP = 0x00 # No-op. Used for keepalive messages
MTYPE_COPEN = 0x01 # Open Channel messages
MTYPE_CCLO = 0x02 # Close Channel messages
MTYPE_CADDR = 0x03 # Channel Address (remote endpoint address info)
MTYPE_DATA = 0x10 # Data messages
def recvall(s, size):
data = ''
while len(data) < size:
d = s.recv(size - len(data))
if not d:
break
data += d
return data
def integer_generator(seed=random.randint(0, 0xffffffff)):
while True:
seed = (seed + 1) % 0xffffffff
yield seed
class Message(object):
""" Container class with (un)serialization methods """
M_HDR_STRUCT = struct.Struct('!BII') # Message Type | Channel ID | Payload Size
def __init__(self, mtype=MTYPE_NOOP, channel=0, size=0):
self.mtype = mtype
self.channel = channel
self.size = size
def __str__(self):
return '<Message type={0} channel={1}>'.format(self.mtype, self.channel)
@classmethod
def unpack(cls, data):
if len(data) < cls.M_HDR_STRUCT.size:
raise ValueError('Attempting to unpack a Message header from too little data')
return Message(*cls.M_HDR_STRUCT.unpack(data[:cls.M_HDR_STRUCT.size])), data[cls.M_HDR_STRUCT.size:]
def pack(self, data=''):
self.size = len(data)
return self.M_HDR_STRUCT.pack(self.mtype, self.channel, self.size) + data
class Channel(object):
""" Container class with remote socket and channel id """
def __init__(self):
self.socket = None # type: socket.socket
self.channel_id = None
self.remote_peer_addr = None
self.local_peer_addr = None
self.socks_handler = SocksHandler()
self.logger = logging.getLogger(self.__class__.__name__)
def __str__(self):
return '<Channel id={0} remote_addr={1} local_addr={2}>'.format(self.channel_id, self.remote_peer_addr, self.local_peer_addr)
@property
def connected(self):
return isinstance(self.socket, socket.socket)
def fileno(self):
return self.socket.fileno()
def close(self):
self.logger.debug('Closing channel {0}'.format(self))
if self.connected:
try:
self.socket.shutdown(socket.SHUT_RDWR)
self.socket.close()
except Exception as e:
self.logger.debug('Unable to close channel: {0}'.format(e))
self.socket = None
class Tunnel(object):
""" Container class with connected transport socket, list of Channels, and methods for passing Messages """
def __init__(self, transport_socket):
self.channels = [] # List[Channel]
self.transport_socket = transport_socket # type: socket.socket
self.logger = logging.getLogger(self.__class__.__name__)
def send_message(self, msg, data=''):
self.logger.debug('Sending {0}'.format(msg))
try:
self.transport_socket.sendall(msg.pack(data))
except (socket.error, TypeError) as e:
self.logger.critical('Problem sending a message over transport: {0}'.format(e))
sys.exit(255)
def recv_message(self):
try:
msg, _ = Message.unpack(recvall(self.transport_socket, Message.M_HDR_STRUCT.size))
except socket.error as e:
self.logger.critical('Problem receiving a message over transport: {0}'.format(e))
sys.exit(255)
return msg, recvall(self.transport_socket, msg.size)
def get_channel_by_id(self, channel_id):
for c in self.channels:
if c.channel_id == channel_id:
return c
raise KeyError('Invalid channel number "{0}"'.format(channel_id))
def open_channel(self, channel_id, remote=False):
c = Channel()
c.channel_id = channel_id
self.channels.append(c)
if remote:
msg = Message(mtype=MTYPE_COPEN, channel=c.channel_id)
self.send_message(msg)
return c
def close_channel(self, channel_id, remote=False):
for c in self.channels:
if c.channel_id == channel_id:
c.close()
self.channels.remove(c)
self.logger.info('Closed channel: {0}'.format(c))
break
if remote:
msg = Message(mtype=MTYPE_CCLO, channel=channel_id)
self.send_message(msg)
return
class SocksHandler(object):
SOCKS5_AUTH_METHODS = {
0x00: 'No Authentication Required',
0x01: 'GSSAPI',
0x02: 'USERNAME/PASSWORD',
0xFF: 'NO ACCEPTABLE METHODS'
}
def __init__(self):
self.auth_handled = False
self.request_handled = False
self.logger = logging.getLogger(self.__class__.__name__)
def handle(self, channel, data):
# SOCKSv5 Auth message
if not self.auth_handled:
data = [ord(x) for x in data]
# Expecting [VERSION | NMETHODS | METHODS] (VERSION must be 0x05)
if len(data) < 2 or data[0] != 0x05 or len(data[2:]) != data[1]:
return struct.pack('BB', 0x05, 0xFF) # No Acceptable Auth Methods
methods = [self.SOCKS5_AUTH_METHODS.get(x, hex(x)) for x in data[2:]]
self.logger.debug('Received SOCKS auth request: {0}'.format(', '.join(methods)))
self.auth_handled = True
return struct.pack('BB', 0x05, 0x00) # No Auth Required
elif not self.request_handled:
if len(data) < 4 or ord(data[0]) != 0x05:
return struct.pack('!BBBBIH', 0x05, 0x01, 0x00, 0x01, 0, 0) # General SOCKS failure
cmd = ord(data[1])
rsv = ord(data[2])
atyp = ord(data[3])
if cmd not in [0x01, 0x02, 0x03]:
return struct.pack('!BBBBIH', 0x05, 0x07, 0x00, 0x01, 0, 0) # Command not supported
if rsv != 0x00:
return struct.pack('!BBBBIH', 0x05, 0x01, 0x00, 0x01, 0, 0) # General SOCKS failure
if atyp not in [0x01, 0x03, 0x04]:
return struct.pack('!BBBBIH', 0x05, 0x08, 0x00, 0x01, 0, 0) # Address type not supported
if cmd == 0x01: # CONNECT
if atyp == 0x01: # IPv4
if len(data) != 10:
return struct.pack('!BBBBIH', 0x05, 0x01, 0x00, 0x01, 0, 0) # General SOCKS failure
host = socket.inet_ntop(socket.AF_INET, data[4:8])
port, = struct.unpack('!H', data[-2:])
af = socket.AF_INET
elif atyp == 0x03: # FQDN
size = ord(data[4])
if len(data[5:]) != size + 2:
return struct.pack('!BBBBIH', 0x05, 0x01, 0x00, 0x01, 0, 0) # General SOCKS failure
host = data[5:5+size]
port, = struct.unpack('!H', data[-2:])
af = socket.AF_INET
atyp = 0x01
elif atyp == 0x04: # IPv6
if len(data) != 22:
return struct.pack('!BBBBIH', 0x05, 0x01, 0x00, 0x01, 0, 0) # General SOCKS failure
host = socket.inet_ntop(socket.AF_INET6, data[5:21])
port, = struct.unpack('!H', data[-2:])
af = socket.AF_INET6
else:
raise NotImplementedError('Failed to implement handler for atype={0}'.format(hex(atyp)))
self.logger.debug('Received SOCKSv5 CONNECT request for {0}:{1}'.format(host, port))
try:
s = socket.socket(af)
s.settimeout(2)
s.connect((host, port))
except socket.timeout:
return struct.pack('!BBBBIH', 0x05, 0x04, 0x00, 0x01, 0, 0) # host unreachable
except socket.error:
return struct.pack('!BBBBIH', 0x05, 0x05, 0x00, 0x01, 0, 0) # connection refused
except Exception:
return struct.pack('!BBBBIH', 0x05, 0x01, 0x00, 0x01, 0, 0) # General SOCKS failure
s.settimeout(None)
channel.socket = s
peer_host, peer_port = s.getpeername()[:2]
channel.local_peer_addr = '{0}[{1}]:{2}'.format(host, peer_host, port)
local_host, local_port = s.getsockname()[:2]
bind_addr = socket.inet_pton(af, local_host)
bind_port = struct.pack('!H', local_port)
ret = struct.pack('!BBBB', 0x05, 0x00, 0x00, atyp) + bind_addr + bind_port
self.logger.info('Connected {0}'.format(channel))
self.request_handled = True
return ret
elif cmd == 0x02: # BIND
raise NotImplementedError('Need to implement BIND command') # TODO
elif cmd == 0x03: # UDP ASSOCIATE
raise NotImplementedError('Need to implement UDP ASSOCIATE command') # TODO
else:
raise NotImplementedError('Failed to implemented handler for cmd={0}'.format(hex(cmd)))
class SocksBase(object):
def __init__(self, transport_addr=('', 443), socks_addr=('', 1080), keepalive=None, key=None, cert=None):
self.tunnel = None # type: Tunnel
self.transport_addr = transport_addr
self.socks_addr = socks_addr
self.keepalive = keepalive
self.socks_socket = None # type: socket.socket
self.next_channel_id = integer_generator()
self.key = key
self.cert = cert
self.logger = logging.getLogger(self.__class__.__name__)
def check_socks_protocol(self, c, data):
return False
def monitor_sockets(self):
while True:
# Check tunnel and peer connections
sockets = [x for x in self.tunnel.channels if x.connected] + [self.tunnel.transport_socket]
if self.socks_socket is not None:
sockets.append(self.socks_socket)
try:
r, _, _ = select.select(sockets, [], [], self.keepalive)
except select.error:
continue
if not r:
msg = Message(mtype=MTYPE_NOOP) # timeout, send keepalive
self.tunnel.send_message(msg)
continue
if self.tunnel.transport_socket in r:
try:
msg, data = self.tunnel.recv_message()
except Exception as e:
self.logger.critical('Error receiving messages, exiting')
self.logger.debug('Error message: {0}'.format(e))
self.tunnel.transport_socket.close()
return
if msg.mtype == MTYPE_NOOP:
self.logger.debug('Received keepalive message, discarding')
elif msg.mtype == MTYPE_COPEN:
c = self.tunnel.open_channel(msg.channel)
self.logger.debug('Received OpenChannel message, opened channel: {0}'.format(c))
elif msg.mtype == MTYPE_CCLO:
try:
c = self.tunnel.get_channel_by_id(msg.channel)
self.tunnel.close_channel(msg.channel)
except KeyError:
pass
else:
self.logger.info('Closed a channel: {0}'.format(c))
elif msg.mtype == MTYPE_CADDR:
try:
c = self.tunnel.get_channel_by_id(msg.channel)
except KeyError:
pass
else:
c.remote_peer_addr = data
self.logger.info('Channel connected remotely: {0}'.format(c))
elif msg.mtype == MTYPE_DATA:
try:
c = self.tunnel.get_channel_by_id(msg.channel)
except KeyError:
pass
else:
self.logger.debug('Received {0} bytes from tunnel for {1}'.format(len(data), c))
if not self.check_socks_protocol(c, data):
try:
c.socket.sendall(data)
except:
self.logger.debug('Problem sending data to channel {0}'.format(c))
self.tunnel.close_channel(msg.channel, remote=True)
else:
self.logger.warning('Received message of unknown type {0}'.format(hex(msg.mtype)))
continue
if self.socks_socket is not None and self.socks_socket in r:
s, addr = self.socks_socket.accept()
addr = '{0}:{1}'.format(*addr)
c = self.tunnel.open_channel(self.next_channel_id.next(), remote=True)
c.local_peer_addr = addr
c.socket = s
self.logger.info('Created new channel: {0}'.format(c))
continue
for c in r:
try:
data = c.socket.recv(1024)
except Exception as e:
self.logger.debug('Problem recving from {0}: {1}'.format(c, e))
self.tunnel.close_channel(c.channel_id, remote=True)
break
if not data:
self.logger.debug('Received EOF from local socket, closing channel')
self.tunnel.close_channel(c.channel_id, remote=True)
msg = Message(mtype=MTYPE_DATA, channel=c.channel_id)
self.tunnel.send_message(msg, data=data)
self.logger.debug('Sent {0} bytes over tunnel: {1}'.format(len(data), msg))
def run(self):
raise NotImplementedError('Subclasses should implement the run() method')
class SocksRelay(SocksBase):
def check_socks_protocol(self, c, data):
if not c.socks_handler.auth_handled:
res = c.socks_handler.handle(c, data)
if not c.socks_handler.auth_handled:
self.logger.warning('SOCKS auth handler failed, expect channel close for {0}'.format(c))
msg = Message(mtype=MTYPE_DATA, channel=c.channel_id)
self.tunnel.send_message(msg, data=res)
return True
elif not c.socks_handler.request_handled:
res = c.socks_handler.handle(c, data)
msg = Message(mtype=MTYPE_DATA, channel=c.channel_id)
self.tunnel.send_message(msg, data=res)
if not c.socks_handler.request_handled:
self.logger.warning('SOCKS req handler failed, expect channel close for {0}'.format(c))
else:
msg = Message(mtype=MTYPE_CADDR, channel=c.channel_id)
self.tunnel.send_message(msg, data=c.local_peer_addr)
return True
else:
return False
def run(self):
s = socket.socket()
s = ssl.wrap_socket(s)
self.logger.debug('Connecting to {0}:{1}'.format(*self.transport_addr))
try:
s.connect(self.transport_addr)
except Exception as e:
self.logger.error('Problem connecting to server: {0}'.format(e))
else:
self.logger.info('Connected to {0}:{1}'.format(*self.transport_addr))
self.tunnel = Tunnel(s)
self.monitor_sockets()
self.logger.warning('SOCKS relay is exiting')
def relay_main(tunnel_addr=''):
tunnel_addr = (tunnel_addr.split(':')[0], int(tunnel_addr.split(':')[1]))
relay = SocksRelay(transport_addr=tunnel_addr)
relay.run()
return
relay_main(tunnel_addr='IPADDRESS:443')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment