Last active
April 11, 2024 20:05
-
-
Save totekuh/1cbf7c1bf5379635a62b99ebde52e2d3 to your computer and use it in GitHub Desktop.
A Python DNS Proxy that forwards UDP DNS queries to a TCP DNS server, with support for domain-specific resolver routing.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
import errno
import os
import socket
import sys
from threading import Thread
from typing import Optional
# Domain resolver file format: one "<pattern>=<proto>://<ip>:<port>" entry per
# line; blank lines and lines starting with '#' are ignored. A bare "*" pattern
# is the catch-all entry. Example:
# *.openai.com=tcp://1.1.1.1:53
# *=tcp://8.8.8.8:53
class DomainResolverEntry:
    """One routing rule from the domain resolver file.

    Attributes:
        domain_name: pattern the rule applies to, e.g. "*.openai.com" or "*".
        resolver_proto: "tcp" or "udp", normalised to lower case.
        resolver_ip: address of the upstream resolver.
        resolver_port: port of the upstream resolver.
    """

    SUPPORTED_PROTOCOLS = ('tcp', 'udp')

    def __init__(self,
                 domain_name: str,
                 resolver_proto: str,
                 resolver_ip: str,
                 resolver_port: int):
        """Validate and store one resolver rule.

        Prints an error and terminates the process with exit status 1 (via
        ``sys.exit``) when ``resolver_proto`` is not tcp or udp.
        """
        self.domain_name = domain_name
        # URI schemes are case-insensitive (RFC 3986 section 3.1), so "TCP://"
        # in the config file is accepted and stored as "tcp".
        proto = resolver_proto.strip().lower()
        if proto not in self.SUPPORTED_PROTOCOLS:
            print("Error: Resolver protocol must be either 'tcp://' or 'udp://'")
            sys.exit(1)
        self.resolver_proto = proto
        self.resolver_ip = resolver_ip
        self.resolver_port = resolver_port

    def __repr__(self):
        return (f"{type(self).__name__}({self.domain_name!r} -> "
                f"{self.resolver_proto}://{self.resolver_ip}:{self.resolver_port})")
class DnsProxy:
    """UDP DNS front end that relays each query to an upstream resolver.

    Queries arrive on a UDP socket. Each one is routed to the resolver chosen
    from the optional domain resolver file, falling back to the default TCP
    server given to the constructor. The upstream answer is sent back to the
    client over UDP. Every query is handled on its own (daemon) thread.
    """

    # Receive buffer for client datagrams. EDNS(0) queries may exceed the
    # classic 512-byte limit; bytes beyond the buffer size would be discarded
    # or raise, depending on the platform.
    MAX_DATAGRAM_SIZE = 65535

    @staticmethod
    def decode_dns_query(data):
        """Return the QNAME of the first question in a raw DNS message.

        The 12-byte header is skipped and length-prefixed labels are read until
        the root (zero-length) label. Parsing stops early, returning the labels
        read so far, on a compression pointer / reserved label type or on a
        label that runs past the end of ``data``. Non-ASCII bytes are replaced
        with U+FFFD instead of raising, so a malformed query cannot crash the
        worker thread.

        :param data: raw DNS message as received from the client.
        :return: dotted domain name without a trailing dot ('' if none).
        """
        labels = []
        i = 12  # skip the fixed-size DNS header
        while i < len(data):
            length = data[i]
            if length == 0:
                break  # root label terminates the name
            if length & 0xC0:
                break  # pointer / reserved label type: not expected in a question
            i += 1
            if i + length > len(data):
                break  # truncated label
            labels.append(data[i:i + length].decode('ascii', errors='replace'))
            i += length
        return '.'.join(labels)

    def __init__(self, listen_udp_ip: str,
                 listen_udp_port: int,
                 dest_tcp_ip: str,
                 dest_tcp_port: int,
                 domains_resolvers_list: Optional[str] = None,
                 verbose: bool = False,
                 timeout: float = 5.0):
        """Configure the proxy. The listening socket is created but not bound.

        :param listen_udp_ip: IPv4 address to listen on for client queries.
        :param listen_udp_port: UDP port to listen on.
        :param dest_tcp_ip: default upstream TCP resolver, used when no entry
            of the resolver file matches a query.
        :param dest_tcp_port: port of the default upstream TCP resolver.
        :param domains_resolvers_list: optional path to a domain resolver file
            ("pattern=proto://host:port" per line, '#' comments allowed). The
            process exits with status 1 if the file is missing or malformed.
        :param verbose: print one line per proxied query.
        :param timeout: seconds to wait on an upstream resolver before giving up.
        """
        self.listen_udp_ip = listen_udp_ip
        self.listen_udp_port = listen_udp_port
        self.dest_tcp_ip = dest_tcp_ip
        self.dest_tcp_port = dest_tcp_port
        self.verbose = verbose
        self.timeout = timeout
        self.domains_resolver_list = []
        if domains_resolvers_list:
            self.domains_resolver_list = self._load_resolver_file(domains_resolvers_list)
        self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

    @staticmethod
    def _load_resolver_file(path: str) -> list:
        """Read every resolver entry from ``path``; exit(1) on a missing file or bad line."""
        try:
            with open(path, 'r', encoding='utf-8') as f:
                raw_lines = f.readlines()
        except FileNotFoundError:
            print("Error: domains resolver list file does not exist")
            sys.exit(1)
        entries = []
        for line_no, raw in enumerate(raw_lines, start=1):
            line = raw.strip()
            if not line or line.startswith("#"):
                continue
            entry = DnsProxy._parse_resolver_line(line)
            if entry is None:
                print(f"Error: malformed resolver entry on line {line_no}: {line!r}")
                sys.exit(1)
            entries.append(entry)
        return entries

    @staticmethod
    def _parse_resolver_line(line: str):
        """Parse "pattern=proto://host:port" into a DomainResolverEntry, or None if malformed."""
        domain_name, has_eq, proto_addr = line.partition("=")
        proto, has_scheme, addr = proto_addr.partition("://")
        # rpartition keeps IPv6 hosts intact; brackets ("[::1]:53") are optional.
        host, has_port, port = addr.rpartition(":")
        host = host.strip().strip("[]")
        if not (has_eq and has_scheme and has_port and host):
            return None
        try:
            port_num = int(port)
        except ValueError:
            return None
        if not 0 < port_num < 65536:
            return None
        return DomainResolverEntry(domain_name=domain_name.strip(),
                                   resolver_proto=proto.strip(),
                                   resolver_ip=host,
                                   resolver_port=port_num)

    @staticmethod
    def _split_pattern(pattern: str):
        """Normalise a resolver pattern into (is_wildcard, suffix).

        Matching is case-insensitive and ignores a trailing root dot. A bare
        "*" yields (True, ''), i.e. the catch-all.
        """
        pattern = pattern.strip().lower().rstrip('.')
        if pattern.startswith('*'):
            return True, pattern[1:].lstrip('.')
        return False, pattern

    def find_resolver_for_domain(self, domain_name: str):
        """Return the most specific resolver entry matching ``domain_name``, or None.

        Entries are tried longest suffix first. On equal suffixes a non-wildcard
        entry beats a wildcard one, and the catch-all "*" is always tried last,
        regardless of its position in the file.
        """
        def priority(entry):
            is_wildcard, suffix = DnsProxy._split_pattern(entry.domain_name)
            return (is_wildcard and not suffix, -len(suffix), is_wildcard)

        for entry in sorted(self.domains_resolver_list, key=priority):
            if self.matches_domain(domain_name, entry.domain_name):
                return entry
        return None

    def matches_domain(self, domain_name: str, pattern: str):
        """Return True if ``domain_name`` falls under ``pattern`` (case-insensitive).

        "*.company.de" and "company.de" both match "company.de" and any
        subdomain such as "sub.company.de"; "*" matches every name.
        """
        is_wildcard, suffix = DnsProxy._split_pattern(pattern)
        if is_wildcard and not suffix:
            return True
        domain = domain_name.strip().lower().rstrip('.')
        return domain == suffix or domain.endswith('.' + suffix)

    @staticmethod
    def _recv_exact(sock, size: int) -> bytes:
        """Read exactly ``size`` bytes from a stream socket.

        :raises ConnectionError: if the peer closes before ``size`` bytes arrive.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                raise ConnectionError(
                    f"DNS server closed the connection with {remaining} of {size} bytes outstanding")
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def _query_tcp(self, data: bytes, ip: str, port: int) -> bytes:
        """Send ``data`` over DNS-over-TCP (2-byte length prefix) and return the reply."""
        with socket.create_connection((ip, port), timeout=self.timeout) as sock:
            sock.sendall(len(data).to_bytes(2, byteorder='big') + data)
            # TCP is a byte stream: the prefix and the body may arrive split
            # across several segments, so read each one in full.
            length = int.from_bytes(self._recv_exact(sock, 2), byteorder='big')
            return self._recv_exact(sock, length)

    def _query_udp(self, data: bytes, ip: str, port: int) -> bytes:
        """Send ``data`` as a single UDP datagram and return the reply datagram."""
        family, socktype, proto, _, sockaddr = socket.getaddrinfo(
            ip, port, type=socket.SOCK_DGRAM)[0]
        with socket.socket(family, socktype, proto) as sock:
            sock.settimeout(self.timeout)
            sock.connect(sockaddr)  # connected socket only accepts replies from this resolver
            sock.send(data)
            return sock.recv(self.MAX_DATAGRAM_SIZE)

    def proxy_dns_query(self, client_address: tuple, data: bytes):
        """Forward one client query upstream and relay the answer back over UDP.

        Errors are printed and the query is dropped; nothing is raised.
        """
        domain_name = DnsProxy.decode_dns_query(data)
        resolver_entry = self.find_resolver_for_domain(domain_name)
        if resolver_entry:
            proto = str(resolver_entry.resolver_proto).lower()
            dest_ip = resolver_entry.resolver_ip
            dest_port = resolver_entry.resolver_port
        else:
            # Fallback to default DNS server provided
            proto = 'tcp'
            dest_ip = self.dest_tcp_ip
            dest_port = self.dest_tcp_port
        try:
            if self.verbose:
                print(f"Proxying DNS query {domain_name} to {dest_ip}:{dest_port}")
            if proto == 'udp':
                response = self._query_udp(data, dest_ip, dest_port)
            else:
                response = self._query_tcp(data, dest_ip, dest_port)
            if not response:
                return  # never relay an empty datagram to the client
            # NOTE(review): a TCP answer may exceed the client's UDP payload
            # limit; a strict proxy would truncate and set TC. Not handled here.
            self.udp_socket.sendto(response, client_address)
        except ConnectionRefusedError:
            print(f"Error: Failed to connect to DNS {proto.upper()} server "
                  f"{dest_ip}:{dest_port}; connection refused.")
        except socket.timeout:
            print(f"Error: Timed out waiting for DNS {proto.upper()} server "
                  f"{dest_ip}:{dest_port}.")
        except ConnectionError as e:
            print(f"Error: {e}")
        except OSError as e:
            print(f"Unhandled OSError: {e}")
        except Exception as e:
            print(f"Error: {e}")

    def start_udp_server(self):
        """Bind the listening socket and serve queries until interrupted.

        Each datagram is handed to ``proxy_dns_query`` on a daemon thread so
        Ctrl+C exits promptly. The listening socket is closed on the way out.
        """
        try:
            self.udp_socket.bind((self.listen_udp_ip, self.listen_udp_port))
            print(
                f"UDP DNS Proxy listening on {self.listen_udp_ip}:{self.listen_udp_port}, "
                f"forwarding to TCP DNS at {self.dest_tcp_ip}:{self.dest_tcp_port} via SSH tunnel")
            while True:
                data, client_address = self.udp_socket.recvfrom(self.MAX_DATAGRAM_SIZE)
                Thread(target=self.proxy_dns_query,
                       args=(client_address, data),
                       daemon=True).start()
        except OSError as e:
            # errno.EADDRINUSE is 98 on Linux but differs on macOS/Windows.
            if e.errno == errno.EADDRINUSE:
                print(f"Error: Port {self.listen_udp_port} on {self.listen_udp_ip} "
                      f"is already in use by another application.")
            else:
                print(f"Unhandled OSError: {e}")
        except KeyboardInterrupt:
            print()
            print("Interrupted")
            sys.exit()
        except Exception as e:
            print(f"An unexpected error occurred: {e}")
        finally:
            self.udp_socket.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment