Skip to content

Instantly share code, notes, and snippets.

@totekuh
Last active April 11, 2024 20:05
Show Gist options
  • Save totekuh/1cbf7c1bf5379635a62b99ebde52e2d3 to your computer and use it in GitHub Desktop.
Save totekuh/1cbf7c1bf5379635a62b99ebde52e2d3 to your computer and use it in GitHub Desktop.
A Python DNS Proxy that forwards UDP DNS queries to a TCP DNS server, with support for domain-specific resolver routing.
#!/usr/bin/env python3
import errno
import os
import socket
import sys
from threading import Thread
# domain resolver file format example:
# *.openai.com=tcp://1.1.1.1:53
# *=tcp://8.8.8.8:53
class DomainResolverEntry:
    """One routing rule that maps a domain pattern to an upstream DNS resolver.

    Attributes:
        domain_name: Domain pattern exactly as written in the resolver file
            (for example ``*.openai.com``, ``example.com`` or ``*``).
        resolver_proto: Transport used to reach the resolver, stored in
            lowercase: ``'tcp'`` or ``'udp'``.
        resolver_ip: IP address of the resolver.
        resolver_port: Port of the resolver.
    """

    # Transports a resolver entry is allowed to name.
    SUPPORTED_PROTOCOLS = ('tcp', 'udp')

    def __init__(self,
                 domain_name: str,
                 resolver_proto: str,
                 resolver_ip: str,
                 resolver_port: int):
        """Validate the protocol and store the rule.

        The protocol is compared case-insensitively. An unsupported protocol
        prints an error to stderr and terminates the program with exit
        status 1 (``SystemExit``).
        """
        # Validate before touching any state so a rejected entry is never
        # left half-initialised.
        normalized_proto = resolver_proto.lower()
        if normalized_proto not in self.SUPPORTED_PROTOCOLS:
            print("Error: Resolver protocol must be either 'tcp://' or 'udp://', "
                  f"got {resolver_proto!r}",
                  file=sys.stderr)
            sys.exit(1)
        self.domain_name = domain_name
        self.resolver_proto = normalized_proto
        self.resolver_ip = resolver_ip
        self.resolver_port = resolver_port

    def __repr__(self):
        return (f"{type(self).__name__}(domain_name={self.domain_name!r}, "
                f"resolver_proto={self.resolver_proto!r}, "
                f"resolver_ip={self.resolver_ip!r}, "
                f"resolver_port={self.resolver_port!r})")
class DnsProxy:
    """UDP DNS listener that relays every query to an upstream DNS server.

    Queries arrive on a UDP socket. For each one the queried domain name is
    decoded and looked up in an optional list of per-domain resolver rules;
    the query is then forwarded to the matching resolver (over TCP or UDP,
    as the rule says) or, when no rule matches, to the default TCP server.
    The upstream answer is sent back to the client over UDP.
    """

    # Size of the fixed DNS message header that precedes the question section.
    DNS_HEADER_SIZE = 12
    # Longest legal DNS label; a larger length byte is a compression pointer
    # or a reserved label type, neither of which is decoded here.
    MAX_LABEL_LENGTH = 63
    # Receive buffer for UDP datagrams (client queries and UDP upstream
    # replies). 512 bytes truncates queries that carry EDNS options.
    UDP_BUFFER_SIZE = 4096
    # Seconds to wait for an upstream server before giving up, so that a
    # silent server cannot pin a worker thread forever.
    UPSTREAM_TIMEOUT = 5.0

    @staticmethod
    def decode_dns_query(data):
        """Return the dotted domain name from the first question of *data*.

        Parsing is deliberately tolerant because the input comes straight
        from the network: it stops at the terminating zero byte, at a
        truncated label, or at a length byte that is not a plain label
        (> 63), and returns whatever labels were read up to that point.
        Non-ASCII bytes are replaced instead of raising. A message shorter
        than the header yields ``''``.
        """
        # Skip the header and get to the Question section
        question_section = data[DnsProxy.DNS_HEADER_SIZE:]
        labels = []
        i = 0
        while i < len(question_section):
            length = question_section[i]
            if length == 0:
                # End of the domain name
                break
            if length > DnsProxy.MAX_LABEL_LENGTH:
                # Compression pointer or reserved label type: not a label
                break
            # Move past the length byte
            i += 1
            label = question_section[i:i + length]
            if len(label) < length:
                # The packet ends in the middle of this label
                break
            labels.append(label.decode('ascii', errors='replace'))
            # Move past the current label
            i += length
        return '.'.join(labels)

    def __init__(self, listen_udp_ip: str,
                 listen_udp_port: int,
                 dest_tcp_ip: str,
                 dest_tcp_port: int,
                 domains_resolvers_list: str = None,
                 verbose: bool = False):
        """Store the configuration, load resolver rules, create the UDP socket.

        Args:
            listen_udp_ip: Address the UDP listener binds to.
            listen_udp_port: Port the UDP listener binds to.
            dest_tcp_ip: Default upstream DNS server, reached over TCP.
            dest_tcp_port: Port of the default upstream DNS server.
            domains_resolvers_list: Optional path to a resolver file with
                lines such as ``*.openai.com=tcp://1.1.1.1:53``. Blank lines
                and lines starting with ``#`` are ignored.
            verbose: Print one line per proxied query when true.

        Raises:
            SystemExit: The resolver file does not exist.
            ValueError: A line of the resolver file is malformed.
        """
        self.listen_udp_ip = listen_udp_ip
        self.listen_udp_port = listen_udp_port
        self.dest_tcp_ip = dest_tcp_ip
        self.dest_tcp_port = dest_tcp_port
        self.verbose = verbose
        self.domains_resolver_list = []
        if domains_resolvers_list:
            if not os.path.exists(domains_resolvers_list):
                print("Error: domains resolver list file does not exist")
                sys.exit(1)
            self.domains_resolver_list = self._load_resolver_file(domains_resolvers_list)
        # Created only after the configuration is known to be valid, so a bad
        # resolver file never leaves an open socket behind.
        self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

    @staticmethod
    def _load_resolver_file(path):
        """Parse the resolver file at *path* into DomainResolverEntry objects."""
        entries = []
        with open(path, 'r') as f:
            for line_number, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line or line.startswith("#"):
                    continue
                malformed = (f"{path}:{line_number}: malformed resolver entry {line!r}; "
                             f"expected '<domain>=<tcp|udp>://<ip>:<port>'")
                domain_name, equals, proto_addr = line.partition("=")
                proto, scheme_sep, addr = proto_addr.partition("://")
                # Split on the last colon: the port is always the final field.
                ip, colon, port = addr.rpartition(":")
                if not (equals and scheme_sep and colon):
                    raise ValueError(malformed)
                try:
                    resolver_port = int(port)
                except ValueError:
                    raise ValueError(malformed) from None
                entries.append(
                    DomainResolverEntry(domain_name=domain_name.strip(),
                                        resolver_proto=proto.strip(),
                                        resolver_ip=ip.strip(),
                                        resolver_port=resolver_port))
        return entries

    @staticmethod
    def _pattern_priority(entry):
        """Rank a rule: exact names first, ``*.x`` wildcards next, ``*`` last."""
        if entry.domain_name == '*':
            return 2
        if '*' in entry.domain_name:
            return 1
        return 0

    def find_resolver_for_domain(self, domain_name: str):
        """Return the first resolver rule matching *domain_name*, or None.

        Exact names are tried before wildcard patterns, and the ``*``
        catch-all is always tried last so that it cannot shadow a more
        specific rule listed after it in the file. Within each group the
        file order is preserved (the sort is stable).
        """
        for entry in sorted(self.domains_resolver_list, key=self._pattern_priority):
            if self.matches_domain(domain_name, entry.domain_name):
                return entry
        return None

    def matches_domain(self, domain_name: str, pattern: str):
        """Tell whether *domain_name* is covered by *pattern*.

        A pattern matches the domain itself and any of its subdomains, with
        or without a leading ``*.`` (``*.company.de`` and ``company.de`` both
        match ``company.de`` and ``sub.company.de``). A bare ``*`` matches
        everything. The comparison ignores case and trailing dots, because
        DNS names are case-insensitive.
        """
        domain = domain_name.lower().rstrip('.')
        suffix = pattern.lower().rstrip('.')
        if suffix.startswith('*'):
            # Drop the wildcard prefix, whether written as '*' or '*.'
            suffix = suffix.lstrip('*').lstrip('.')
            if not suffix:
                # Bare '*': catch-all
                return True
        return domain == suffix or domain.endswith('.' + suffix)

    @staticmethod
    def _recv_exact(sock, size):
        """Read exactly *size* bytes from a stream socket.

        A single ``recv`` may return fewer bytes than requested, so keep
        reading until the whole amount has arrived. Raises EOFError when the
        peer closes the connection early.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                raise EOFError("DNS server closed the connection before "
                               "sending a complete response")
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def _query_upstream_tcp(self, data, dest_ip, dest_port):
        """Send *data* to a DNS server over TCP and return the raw answer."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp_socket:
            tcp_socket.settimeout(self.UPSTREAM_TIMEOUT)
            tcp_socket.connect((dest_ip, dest_port))
            # DNS over TCP prefixes every message with its 2-byte length.
            tcp_socket.sendall(len(data).to_bytes(2, byteorder='big') + data)
            response_length = int.from_bytes(self._recv_exact(tcp_socket, 2),
                                             byteorder='big')
            return self._recv_exact(tcp_socket, response_length)

    def _query_upstream_udp(self, data, dest_ip, dest_port):
        """Send *data* to a DNS server over UDP and return the raw answer."""
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as upstream_socket:
            upstream_socket.settimeout(self.UPSTREAM_TIMEOUT)
            # Connecting the datagram socket makes it accept replies from
            # this server only.
            upstream_socket.connect((dest_ip, dest_port))
            upstream_socket.send(data)
            return upstream_socket.recv(self.UDP_BUFFER_SIZE)

    def proxy_dns_query(self, client_address: tuple, data: bytes):
        """Forward one client query upstream and relay the answer back.

        The destination is the resolver rule matching the queried name, or
        the default TCP server when no rule matches. Failures are reported
        on stdout and the query is dropped; nothing is raised to the caller,
        which runs this method in a worker thread.
        """
        domain_name = DnsProxy.decode_dns_query(data)
        resolver_entry = self.find_resolver_for_domain(domain_name)
        if resolver_entry is not None:
            dest_proto = resolver_entry.resolver_proto
            dest_ip = resolver_entry.resolver_ip
            dest_port = resolver_entry.resolver_port
        else:
            # Fallback to default DNS server provided
            dest_proto = 'tcp'
            dest_ip = self.dest_tcp_ip
            dest_port = self.dest_tcp_port
        if self.verbose:
            print(f"Proxying DNS query {domain_name} to {dest_ip}:{dest_port}")
        try:
            if dest_proto == 'udp':
                response = self._query_upstream_udp(data, dest_ip, dest_port)
            else:
                response = self._query_upstream_tcp(data, dest_ip, dest_port)
            # Send the response back to the client over UDP.
            # NOTE: the response is relayed as received; it is not truncated
            # to the UDP payload size the client may have advertised.
            self.udp_socket.sendto(response, client_address)
        except socket.timeout:
            print(f"Error: Timed out waiting for DNS server {dest_ip}:{dest_port}.")
        except ConnectionRefusedError:
            print(f"Error: Failed to connect to DNS {dest_proto.upper()} server; "
                  f"connection refused.")
        except OSError as e:
            print(f"Unhandled OSError: {e}")
        except Exception as e:
            print(f"Error: {e}")

    def start_udp_server(self):
        """Bind the UDP socket and serve queries until interrupted.

        Every received datagram is handled in its own daemon thread, so
        pending upstream requests never delay shutdown. Ctrl+C prints
        "Interrupted" and exits the program; other errors are reported on
        stdout and end the loop.
        """
        try:
            self.udp_socket.bind((self.listen_udp_ip, self.listen_udp_port))
            print(
                f"UDP DNS Proxy listening on {self.listen_udp_ip}:{self.listen_udp_port}, "
                f"forwarding to TCP DNS at {self.dest_tcp_ip}:{self.dest_tcp_port} via SSH tunnel")
            while True:
                data, client_address = self.udp_socket.recvfrom(self.UDP_BUFFER_SIZE)
                Thread(target=self.proxy_dns_query,
                       args=(client_address, data),
                       daemon=True).start()
        except OSError as e:
            # Compare against the symbolic constant: the numeric value of
            # EADDRINUSE differs between operating systems.
            if e.errno == errno.EADDRINUSE:
                print(f"Error: Port {self.listen_udp_port} on {self.listen_udp_ip} "
                      f"is already in use by another application.")
            else:
                print(f"Unhandled OSError: {e}")
        except KeyboardInterrupt:
            print()
            print("Interrupted")
            sys.exit()
        except Exception as e:
            print(f"An unexpected error occurred: {e}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment