Created
March 14, 2022 06:53
-
-
Save muxueqz/d86d075ab72182e487db0771b4f47b7e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from typing import Optional, Tuple, List | |
import struct | |
import asyncio | |
import pathlib | |
import socket | |
import ipaddress | |
from os import environ as env | |
from collections import deque | |
from typing import Optional, Tuple | |
import sys | |
# uvloop is an optional, faster drop-in event loop; fall back to the stdlib
# loop when it is not installed. Only a missing module is tolerated — a bare
# ``except:`` would also swallow KeyboardInterrupt/SystemExit.
try:
    import uvloop  # type: ignore
    uvloop.install()
except ImportError:
    pass
def parse_query(data: bytes) -> Tuple[int, List[bytes]]:
    """Split a raw DNS message into its transaction id and question entries.

    Each returned entry is the raw wire bytes of one question: the encoded
    QNAME (length-prefixed labels ending in a zero byte, or in a 2-byte
    compression pointer) followed by the 4-byte QTYPE/QCLASS pair.

    Raises:
        struct.error: if *data* is shorter than the 12-byte DNS header.
        ValueError: if the question section is truncated.
    """
    # Only the id and the question count are used; flags and the
    # answer/authority/additional counts are ignored.
    transaction_id, _flags, num_queries, _an, _ns, _ar = struct.unpack_from(
        ">6H", data
    )
    queries: List[bytes] = []
    offset = 12
    for _ in range(num_queries):
        start = offset
        # Walk labels by their length prefixes rather than searching for the
        # first zero byte: a 0x00 byte may legally appear inside a label.
        while True:
            if offset >= len(data):
                raise ValueError("truncated DNS question name")
            length = data[offset]
            if length == 0:
                offset += 1
                break
            if length & 0xC0 == 0xC0:
                # Compression pointer: two bytes, and the name ends here.
                offset += 2
                break
            offset += 1 + length
        offset += 4  # QTYPE + QCLASS
        if offset > len(data):
            raise ValueError("truncated DNS question")
        queries.append(data[start:offset])
    return transaction_id, queries
def get_domain(query: bytes) -> str:
    """Decode the QNAME at the start of a raw question entry into dotted form.

    Labels are read by their length prefixes until the terminating zero
    byte. Decoding also stops (instead of raising) when the bytes run out or
    at a length byte with either of its top two bits set (a compression
    pointer or reserved label type). Non-ASCII bytes are replaced with
    U+FFFD rather than raising UnicodeDecodeError.
    """
    labels: List[bytes] = []
    pos = 0
    end = len(query)
    while pos < end:
        length = query[pos]
        if length == 0 or length & 0xC0:
            break
        pos += 1
        labels.append(query[pos:pos + length])
        pos += length
    return ".".join(label.decode("ascii", errors="replace") for label in labels)
def build_answer(
    trans_id: int,
    queries: List[bytes],
    answer: Optional[bytes] = None,
    ttl: int = 128,
    rtype: int = 1,
) -> bytes:
    """Build an authoritative DNS response for *queries*.

    Args:
        trans_id: transaction id copied from the request.
        queries: raw question entries, echoed back verbatim in the question
            section.
        answer: RDATA of the single answer record (a packed 4-byte IPv4
            address for the default A type). When falsy, the response has
            no answer record and RCODE NXDOMAIN.
        ttl: TTL of the answer record, in seconds.
        rtype: TYPE of the answer record (1 = A; e.g. 28 = AAAA with a
            16-byte *answer*).

    The answer's owner name is a compression pointer to offset 12, i.e. the
    name of the first question.
    """
    flags = 0x8000 | 0x0400  # QR (response) | AA (authoritative answer)
    if not answer:
        flags |= 0x0003  # RCODE 3: NXDOMAIN
    ancount = 1 if answer else 0
    header = struct.pack(">6H", trans_id, flags, len(queries), ancount, 0, 0)
    parts = [header, *queries]
    if answer:
        parts.append(b"\xc0\x0c")  # owner name: pointer to first QNAME
        # TYPE, CLASS (IN), TTL, RDLENGTH — derived from the data so a
        # mismatched answer length can never produce a malformed packet.
        parts.append(struct.pack(">HHIH", rtype, 1, ttl, len(answer)))
        parts.append(answer)
    return b"".join(parts)
def get_default_resolver(resolv_conf: str = "/etc/resolv.conf") -> str:
    """Return the first ``nameserver`` address listed in *resolv_conf*.

    Text after ``#`` is ignored, and only lines made of exactly the
    ``nameserver`` keyword plus one value are considered. Falls back to
    ``"8.8.8.8"`` when the file is missing or has no such line.
    """
    conf_path = pathlib.Path(resolv_conf)
    if not conf_path.is_file():
        return "8.8.8.8"
    with conf_path.open() as handle:
        for raw in handle:
            tokens = raw.strip().split("#", 1)[0].split()
            if len(tokens) == 2 and tokens[0] == "nameserver":
                return tokens[1]
    return "8.8.8.8"
class DNSForward(asyncio.DatagramProtocol):
    """One-shot UDP protocol that relays a DNS message to an upstream server.

    On connection it sends *message*; the first datagram received is kept as
    the result and the transport is closed. When the connection is lost,
    *on_exit* is resolved with that reply, or with ``None`` if none arrived
    (e.g. the socket reported an error).
    """

    def __init__(self, message: bytes, on_exit: asyncio.Future):
        self.message = message
        self.on_exit = on_exit
        self.result: Optional[bytes] = None
        # Attached in connection_made; None until then (a dummy transport
        # instance would always be truthy and hide the "not connected" state).
        self.transport: Optional[asyncio.DatagramTransport] = None

    def datagram_received(self, data: Optional[bytes], addr: Tuple[str, int]):
        self.result = data
        self._close()

    def connection_made(self, transport: asyncio.DatagramTransport):  # type: ignore[override]
        self.transport = transport
        transport.sendto(self.message)

    def error_received(self, exc):
        # An error on the connected socket (e.g. ICMP port unreachable) means
        # no reply is coming: close so connection_lost resolves the waiter
        # with None instead of leaving it pending forever.
        self._close()

    def connection_lost(self, exc):
        # The waiter may already be done (e.g. cancelled by a timeout);
        # setting its result again would raise InvalidStateError.
        if not self.on_exit.done():
            self.on_exit.set_result(self.result)

    def _close(self):
        """Close the transport if one has been attached."""
        if self.transport is not None:
            self.transport.close()
# Static name -> IPv4 address table, loaded from the hosts-format file given
# as the first command-line argument.
DB = {}
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} HOSTS_FILE")
hosts_file = sys.argv[1]
with open(hosts_file, "r", encoding="utf-8") as fd:
    for line in fd:
        # "ADDRESS NAME [ALIAS...]"; "#" starts a comment.
        record = line.split("#")[0].split()
        if len(record) < 2:
            continue
        ip, names = record[0], record[1:]
        try:
            ipaddress.IPv4Address(ip)
        except ValueError:
            # Only A records can be answered. Skip IPv6 (e.g. "::1 localhost")
            # and malformed entries so they cannot overwrite an IPv4 entry for
            # the same name and later fail in IPv4Address() when answering.
            continue
        for name in names:
            DB[name] = ip
print(f"loaded {len(DB)} names from {hosts_file}")
# Upstream resolver for names missing from DB: $DNS_RELAY, else the first
# nameserver in /etc/resolv.conf (get_default_resolver), always on port 53.
DNS_RELAY = (env.get("DNS_RELAY", get_default_resolver()), 53)
# Address to bind the listening socket to ($BIND, default all interfaces).
HOST = env.get("BIND", "0.0.0.0")
class DNSServer:
    """Minimal UDP DNS server: answers A queries from DB, relays other names.

    Datagrams are read from a non-blocking socket via event-loop reader
    callbacks, and each request is handled in its own task. Replies are
    queued by ``send`` and written by a single sender task.
    """

    def __init__(self, loop=None, relay_timeout: float = 5.0):
        # Look the loop up when the server is created. The old default
        # ``loop=asyncio.get_event_loop()`` was evaluated once, at import time.
        self.loop = loop if loop is not None else asyncio.get_event_loop()
        self.sock = None
        self.event = asyncio.Event()
        self.queue = deque()
        # Seconds to wait for the upstream resolver before answering SERVFAIL.
        self.relay_timeout = relay_timeout
        # Strong references to background tasks; the event loop only keeps
        # weak references, so unreferenced tasks could be collected.
        self._tasks = set()

    def _spawn(self, coro) -> asyncio.Task:
        """Schedule *coro* as a task and hold a reference until it finishes."""
        task = self.loop.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task

    @staticmethod
    def _servfail(trans_id: int, queries: List[bytes]) -> bytes:
        """Return a SERVFAIL response (QR set, RCODE 2) echoing *queries*."""
        header = struct.pack(">6H", trans_id, 0x8002, len(queries), 0, 0, 0)
        return header + b"".join(queries)

    async def _relay(self, data: bytes, trans_id: int, queries: List[bytes]) -> bytes:
        """Forward *data* to DNS_RELAY and return its reply, or SERVFAIL."""
        on_exit = self.loop.create_future()
        try:
            transport, _ = await self.loop.create_datagram_endpoint(
                lambda: DNSForward(data, on_exit), remote_addr=DNS_RELAY
            )
        except OSError as ex:
            print(f"cannot reach relay {DNS_RELAY}: {ex}")
            return self._servfail(trans_id, queries)
        try:
            # shield() keeps on_exit itself uncancelled on timeout, so the
            # protocol can still resolve it harmlessly when we close below.
            reply = await asyncio.wait_for(
                asyncio.shield(on_exit), self.relay_timeout
            )
        except asyncio.TimeoutError:
            reply = None
        finally:
            transport.close()
        return reply if reply else self._servfail(trans_id, queries)

    async def on_data_received(self, data: bytes, addr: Tuple[str, int]):
        """Answer one request from DB, or relay it upstream on a miss."""
        try:
            trans_id, queries = parse_query(data)
            domains = [get_domain(q) for q in queries]
        except (struct.error, ValueError, IndexError) as ex:
            # Malformed packet: drop it instead of letting the task die with
            # an unretrieved exception.
            print(f"malformed query from {addr}: {ex}")
            return
        res = None  # also covers a request with zero questions
        for domain in domains:
            print(trans_id, domain)
            res = DB.get(domain)
            if not res and "." in domain:
                # Relay the whole original message; the upstream reply keeps
                # the client's transaction id.
                self.send(await self._relay(data, trans_id, queries), addr)
                return
        try:
            ip = ipaddress.IPv4Address(res).packed if res else None
        except ValueError:
            # Only A records can be built; a non-IPv4 entry gets SERVFAIL
            # rather than crashing the handler.
            print(f"cannot answer with {res!r}: not an IPv4 address")
            self.send(self._servfail(trans_id, queries), addr)
            return
        self.send(build_answer(trans_id, queries, answer=ip), addr)

    def run(self, host: str = "0.0.0.0", port: int = 53):
        """Bind the UDP socket and start the receive and send loops."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.setblocking(False)
            sock.bind((host, port))
        except OSError:
            sock.close()  # don't leak the descriptor if binding fails
            raise
        self.sock = sock
        self._spawn(self.recv_periodically())
        self._spawn(self.send_periodically())

    def sock_recv(
        self, fut: Optional[asyncio.Future] = None, registered: bool = False
    ) -> Optional[asyncio.Future]:
        """Receive one datagram without blocking.

        Returns a future resolved with ``(data, addr)``, or with ``0`` when
        recvfrom fails unexpectedly. If the socket is not ready yet, this
        method registers itself as the reader callback and resolves *fut*
        when it is called again.
        """
        fd = self.sock.fileno()
        if fut is None:
            fut = self.loop.create_future()
        if registered:
            self.loop.remove_reader(fd)
        if fut.cancelled():
            return fut  # waiter is gone; setting a result would raise
        try:
            data, addr = self.sock.recvfrom(2048)
        except (BlockingIOError, InterruptedError):
            self.loop.add_reader(fd, self.sock_recv, fut, True)
        except Exception as ex:
            print(ex)
            fut.set_result(0)
        else:
            fut.set_result((data, addr))
        return fut

    async def recv_periodically(self):
        """Receive datagrams forever, handling each request in its own task."""
        while True:
            received = await self.sock_recv()
            if not received:
                # recvfrom failed and was already logged. Keep serving instead
                # of dying on unpacking 0, but stop if the socket is closed.
                if self.sock.fileno() < 0:
                    return
                await asyncio.sleep(0)  # yield so a failing recv can't starve the loop
                continue
            data, addr = received
            self._spawn(self.on_data_received(data, addr))

    def send(self, data: bytes, addr: Tuple[str, int]):
        """Queue *data* for delivery to *addr* and wake the sender task."""
        self.queue.append((data, addr))
        self.event.set()

    def sock_send(
        self,
        data: bytes,
        addr: Tuple[str, int],
        fut: Optional[asyncio.Future] = None,
        registered: bool = False,
    ) -> Optional[asyncio.Future]:
        """Send one datagram without blocking.

        Returns a future resolved with the byte count sent, or with ``0``
        when sendto fails unexpectedly. If the socket is not writable, this
        method registers itself as the writer callback and retries later.
        """
        fd = self.sock.fileno()
        if fut is None:
            fut = self.loop.create_future()
        if registered:
            self.loop.remove_writer(fd)
        if fut.cancelled():
            return fut
        try:
            sent = self.sock.sendto(data, addr)
        except (BlockingIOError, InterruptedError):
            self.loop.add_writer(fd, self.sock_send, data, addr, fut, True)
        except Exception as ex:
            print(ex)
            fut.set_result(0)
        else:
            fut.set_result(sent)
        return fut

    async def send_periodically(self):
        """Drain the reply queue whenever ``send`` signals new items."""
        while True:
            await self.event.wait()
            try:
                while self.queue:
                    data, addr = self.queue.popleft()
                    await self.sock_send(data, addr)
            finally:
                # No await between the empty-queue check and clear(), so a
                # concurrent send() cannot be lost.
                self.event.clear()
async def main(loop):
    """Start a DNSServer on *loop*, listening on HOST port 53.

    Returns once the socket is bound; serving continues in the tasks that
    ``run`` schedules on the loop.
    """
    server = DNSServer(loop)
    server.run(port=53, host=HOST)
if __name__ == "__main__":
    # Create the loop explicitly: asyncio.get_event_loop() with no running
    # loop is deprecated (and an error on recent Pythons). new_event_loop()
    # uses the current event-loop policy, so uvloop is still used if it was
    # installed at import time.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main(loop))
        loop.run_forever()
    except KeyboardInterrupt:
        pass  # Ctrl-C: shut down quietly
    finally:
        loop.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment