Skip to content

Instantly share code, notes, and snippets.

@adiroiban
Created July 6, 2021 19:02
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save adiroiban/b13989e2bdb97be6b5d7655bff38101a to your computer and use it in GitHub Desktop.
ProxyProtocol v2 wrapper
"""
Proxy protocol support.
Code based on https://github.com/icgood/proxy-protocol
http://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
"""
from __future__ import unicode_literals
import socket
import struct
from ipaddress import IPv4Address, IPv6Address
from twisted.internet import address
from chevah.server import force_unicode
from chevah.server.commons.exception import ServerException
class ProxyProtocolResult(object):
    """
    Read-only holder for the values decoded from a PROXY protocol payload.

    `source` and `dest` are the (address, port) pairs announced by the
    proxy; `family` and `protocol` are the socket family and socket type.
    """

    def __init__(self, source, destination, family, protocol=None):
        (self._source, self._destination,
         self._family, self._protocol) = (
            source, destination, family, protocol)

    @property
    def proxied(self):
        """
        Always True: instances describe a proxied connection.
        """
        return True

    @property
    def source(self):
        """
        The original client endpoint, as passed to the constructor.
        """
        return self._source

    @property
    def dest(self):
        """
        The endpoint the client originally connected to.
        """
        return self._destination

    @property
    def family(self):
        """
        The socket address family of the proxied connection.
        """
        return self._family

    @property
    def protocol(self):
        """
        The socket type of the proxied connection, or None if unspecified.
        """
        return self._protocol
# PROXY protocol v2 wire values paired with their Python meaning.
# The `*_l` maps decode (wire value -> Python value); the `*_r` maps encode
# (Python value -> wire value).
_commands = [
    (0x00, 'local'),
    (0x01, 'proxy'),
]
# Address family lives in the high nibble of byte 13.
_families = [
    (0x00, socket.AF_UNSPEC),
    (0x10, socket.AF_INET),
    (0x20, socket.AF_INET6),
]
# Transport protocol lives in the low nibble of byte 13; None = unspecified.
_protocols = [
    (0x00, None),
    (0x01, socket.SOCK_STREAM),
    (0x02, socket.SOCK_DGRAM),
]

_commands_l = dict(_commands)
_commands_r = dict((value, wire) for wire, value in _commands)
_families_l = dict(_families)
_families_r = dict((value, wire) for wire, value in _families)
_protocols_l = dict(_protocols)
_protocols_r = dict((value, wire) for wire, value in _protocols)
def _parse_connection(data):
    """
    Parse a PROXY protocol v2 header from the start of `data`.

    Return a `(result, payload)` tuple:

    * `result` is a `ProxyProtocolResult` with the proxied source and
      destination, or `None` when `data` does not start with the v2
      signature or when the header carries the LOCAL command.
    * `payload` is whatever follows the header (including any TLVs being
      skipped); when no header is present it is `data` unchanged.

    The whole header is expected to be present in `data`.

    Raises ServerException when the header is truncated, has a version
    other than 2, a command other than LOCAL/PROXY, or (for PROXY) an
    address family other than AF_INET/AF_INET6.
    """
    if not data.startswith(b'\r\n\r\n\x00\r\nQUIT\n'):
        return None, data

    if len(data) < 16:
        raise ServerException('Incomplete proxy protocol header.')

    byte_12, byte_13, addr_len = struct.unpack('!BBH', data[12:16])
    # Use the unpacked integer rather than ord(data[12]): indexing bytes
    # already yields an int on Python 3, where ord() would fail.
    if byte_12 & 0xf0 != 0x20:
        raise ServerException(
            'Only proxy protocol version 2 is supported.')

    header_end = 16 + addr_len
    if len(data) < header_end:
        raise ServerException('Incomplete proxy protocol header.')

    command = _commands_l.get(byte_12 & 0x0f)
    family = _families_l.get(byte_13 & 0xf0)
    protocol = _protocols_l.get(byte_13 & 0x0f)
    payload = data[header_end:]

    if command == 'local':
        # A local connection from the proxy. Any address block it carries
        # must be discarded, not passed on as application payload.
        return None, payload

    if command != 'proxy':
        raise ServerException('Only the "proxy" command is supported.')

    if family == socket.AF_INET:
        ip_class, address_format = IPv4Address, '!4s4sHH'
    elif family == socket.AF_INET6:
        ip_class, address_format = IPv6Address, '!16s16sHH'
    else:
        raise ServerException('Unknown proxied connection family.')

    address_data = data[16:header_end]
    if len(address_data) < struct.calcsize(address_format):
        raise ServerException('Incomplete proxy protocol address block.')

    # unpack_from() ignores trailing bytes, so TLV vectors appended after
    # the addresses (counted in addr_len) do not break parsing.
    source_ip, dest_ip, source_port, dest_port = struct.unpack_from(
        address_format, address_data)
    result = ProxyProtocolResult(
        (ip_class(source_ip), source_port),
        (ip_class(dest_ip), dest_port),
        # Report the real family (the IPv6 case used to say AF_INET).
        family=family,
        protocol=protocol,
    )
    return result, payload
def buildV2(source, destination, family, protocol=None):
    """
    Return a complete PROXY protocol v2 message for a proxied connection.

    `source` and `destination` are `(host, port)` pairs. The message is the
    16-byte header followed by the packed address block. Only AF_INET and
    AF_INET6 are accepted; any other family raises ServerException.
    """
    if family == socket.AF_INET:
        ip_class, layout = IPv4Address, '!4s4sHH'
    elif family == socket.AF_INET6:
        ip_class, layout = IPv6Address, '!16s16sHH'
    else:
        raise ServerException('Address family is not supported.')

    packed_source = ip_class(source[0]).packed
    source_port = source[1]
    packed_dest = ip_class(destination[0]).packed
    dest_port = destination[1]
    address_block = struct.pack(
        layout, packed_source, packed_dest, source_port, dest_port)

    header = _build_header(
        address_block, family=family, protocol=protocol)
    return header + address_block
def buildV2local(family):
    """
    Return the PROXY protocol v2 header for a LOCAL (non-proxied)
    connection. It is sent with an empty address block.
    """
    empty_block = []
    return _build_header(
        empty_block,
        family=family,
        proxied=False,
    )
def _build_header(addresses, family, protocol=None, proxied=True):
    """
    Build the 16-byte block that begins every PROXY protocol v2 header.

    Args:
        addresses: The packed address block that will follow the header.
            Only its length is used, to fill the length field.
        family: The original socket family (a key of `_families_r`).
        protocol: The original socket type (a key of `_protocols_r`);
            None encodes an unspecified protocol.
        proxied: True to emit the PROXY command; False to emit the LOCAL
            command, meaning the connection should not be considered
            proxied.

    Returns:
        The 12-byte signature, then the version/command byte, the
        family/protocol byte and the big-endian length of `addresses`.

    Raises:
        KeyError: if `family` or `protocol` is not in the lookup tables.
    """
    # High nibble 0x2 is the protocol version; low nibble is the command.
    byte_12 = 0x20 | _commands_r['proxy' if proxied else 'local']
    # High nibble is the address family; low nibble is the transport.
    byte_13 = _families_r[family] | _protocols_r[protocol]
    return b'\r\n\r\n\x00\r\nQUIT\n' + struct.pack(
        '!BBH', byte_12, byte_13, len(addresses))
# Twisted address class to build for each IP version; keyed by the
# `version` attribute (4 or 6) of a parsed `ipaddress` address object.
_TWISTED_MAPPING = dict([
    (4, address.IPv4Address),
    (6, address.IPv6Address),
])
def proxify_v2(target_protocol):
    """
    Make `target_protocol` accept a leading PROXY protocol v2 header.

    A one-shot hook replaces `target_protocol.dataReceived`. On the first
    chunk it parses and strips the header, then restores the original
    handler. When the header announces a proxied connection, it also
    replaces the avatar peer with the proxied client address and emits
    event '20003' with the proxy's own address.
    """
    original_receive = target_protocol.dataReceived

    def _innermost(outer):
        """
        Peel up to two `wrappedProtocol` layers to reach the application
        protocol. Two layers mean an SSL connection, one layer an FTPS data
        connection, and none a plain connection.
        """
        current = outer
        for _ in range(2):
            if not hasattr(current, 'wrappedProtocol'):
                break
            current = current.wrappedProtocol
        return current

    def _pp2Receive(data):
        """
        Handle the first chunk of data received on the connection.

        Assumes the whole Proxy Protocol header arrives in this chunk.
        """
        parsed, payload = _parse_connection(data)

        if not parsed:
            # No header, or a LOCAL connection from the proxy itself.
            target_protocol.dataReceived = original_receive
            if payload:
                original_receive(payload)
            return

        app_protocol = _innermost(target_protocol)

        proxy_peer = app_protocol.avatar.peer
        proxy_host = proxy_peer.host
        proxy_port = proxy_peer.port

        client = parsed.source
        address_class = _TWISTED_MAPPING[client[0].version]
        app_protocol.avatar._peer = address_class(
            host=client[0].compressed.encode('ascii'),
            port=client[1],
            type=proxy_peer.type,
        )

        app_protocol.emitEvent('20003', data={
            'host': force_unicode(proxy_host),
            'port': proxy_port,
        })

        target_protocol.dataReceived = original_receive
        if not payload:
            # Writing empty data confuses the TLS/SSL wrapper.
            return
        return original_receive(payload)

    target_protocol.dataReceived = _pp2Receive
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment