Skip to content

Instantly share code, notes, and snippets.

@muxueqz
Created March 14, 2022 06:53
Show Gist options
  • Save muxueqz/d86d075ab72182e487db0771b4f47b7e to your computer and use it in GitHub Desktop.
Save muxueqz/d86d075ab72182e487db0771b4f47b7e to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
from typing import Optional, Tuple, List
import struct
import asyncio
import pathlib
import socket
import ipaddress
from os import environ as env
from collections import deque
from typing import Optional, Tuple
import sys
# Use uvloop as the event loop implementation when it is installed; fall back
# to the default asyncio loop otherwise.  Only ImportError is swallowed so that
# KeyboardInterrupt / SystemExit and genuine uvloop failures are not hidden.
try:
    import uvloop  # type: ignore

    uvloop.install()
except ImportError:
    pass
def parse_query(data: bytes) -> Tuple[int, List[bytes]]:
    """Split a raw DNS request into its transaction id and question entries.

    The first 12 bytes are unpacked as six big-endian unsigned shorts; only
    the transaction id and the question count are used.  Each question is
    taken to run from the current position up to the next zero byte plus the
    four bytes that follow it.

    Args:
        data: the raw datagram.

    Returns:
        ``(transaction_id, questions)`` where each question is the raw bytes
        of one entry.

    Raises:
        struct.error: if fewer than 12 header bytes are available.
        ValueError: if no zero byte is found for a question.
    """
    transaction_id, _flags, question_count, *_other_counts = struct.unpack(
        ">6H", data[:12]
    )
    questions: List[bytes] = []
    cursor = 12  # first byte after the fixed-size header
    for _ in range(question_count):
        # zero terminator + 1, then 4 trailing bytes
        stop = data.index(0, cursor) + 5
        questions.append(data[cursor:stop])
        cursor = stop
    return transaction_id, questions
def get_domain(query: bytes) -> str:
    """Decode the length-prefixed name at the start of *query* as dotted text.

    The name is read as a sequence of labels, each preceded by a one-byte
    length, and ends at a zero length byte.  Bytes after the terminator are
    ignored.  If the data ends before a terminator is seen (empty or
    truncated input), the labels read so far are returned instead of raising
    ``IndexError``.

    Args:
        query: raw bytes of one question entry.

    Returns:
        The labels joined with ``"."``; an empty string when there are none.

    Raises:
        UnicodeDecodeError: if a label contains non-ASCII bytes.
    """
    labels: List[str] = []
    offset = 0
    size = len(query)
    while offset < size:
        length = query[offset]
        if length == 0:
            break
        offset += 1
        # Slicing clamps at the end of the buffer, so a truncated final label
        # yields whatever bytes are present.
        labels.append(query[offset:offset + length].decode("ascii"))
        offset += length
    return ".".join(labels)
def build_answer(
    trans_id: int,
    queries: List[bytes],
    answer: Optional[bytes] = None,
    ttl: int = 128,
) -> bytes:
    """Assemble a DNS response packet for the given questions.

    Args:
        trans_id: transaction id copied into the response header.
        queries: raw question entries, echoed back unchanged.
        answer: record data for a single answer; when falsy the response
            carries no answer and the NXDOMAIN code.
        ttl: TTL written into the answer record.

    Returns:
        The 12-byte header followed by the questions and, if present, one
        answer record.
    """
    has_answer = bool(answer)

    # 0x8000 and 0x0400 are always set; 0x0003 (NXDOMAIN) is added when there
    # is nothing to answer with.
    flags = 0x8000 | 0x0400
    if not has_answer:
        flags |= 0x0003

    header = struct.pack(
        ">6H", trans_id, flags, len(queries), int(has_answer), 0, 0
    )

    sections = list(queries)
    if has_answer:
        sections.append(b"\xc0\x0c")
        # two 16-bit fields set to 1 (A / IN), the TTL, and a data length of 4
        sections.append(struct.pack(">HHLH", 1, 1, ttl, 4))
        sections.append(answer)
    return header + b"".join(sections)
def get_default_resolver(resolv_conf: str = "/etc/resolv.conf") -> str:
    """Return the first ``nameserver`` entry found in *resolv_conf*.

    Text after a ``#`` on a line is ignored.  A line counts only when it has
    exactly two whitespace-separated fields and the first is ``nameserver``.
    ``"8.8.8.8"`` is returned when the path is not a regular file or no such
    line exists.
    """
    fallback = "8.8.8.8"
    conf_path = pathlib.Path(resolv_conf)
    if not conf_path.is_file():
        return fallback
    with conf_path.open() as handle:
        for raw_line in handle:
            fields = raw_line.strip().partition("#")[0].split()
            if len(fields) == 2 and fields[0] == "nameserver":
                return fields[1]
    return fallback
class DNSForward(asyncio.DatagramProtocol):
    """Datagram protocol that sends one message and reports the first reply.

    On connection the stored message is sent.  The first datagram received is
    kept as the result and the transport is closed.  When the connection is
    lost, ``on_exit`` is resolved with the result, or with ``None`` if nothing
    was received.
    """

    def __init__(self, message: bytes, on_exit: asyncio.Future):
        """Store the outgoing *message* and the future to resolve on exit."""
        self.message = message
        self.on_exit = on_exit
        self.result: Optional[bytes] = None
        # Set by connection_made(); None until then rather than a dummy
        # DatagramTransport whose methods are unimplemented.
        self.transport: Optional[asyncio.DatagramTransport] = None

    def datagram_received(self, data: Optional[bytes], addr: Tuple[str, int]):
        """Remember the reply and close the transport; one reply is enough."""
        self.result = data
        if self.transport:
            self.transport.close()

    def connection_made(self, transport: asyncio.DatagramTransport):  # type: ignore[override]
        """Keep the transport and send the stored message."""
        self.transport = transport
        self.transport.sendto(self.message)

    def error_received(self, exc):
        """Close the transport on a send/receive error.

        Closing leads to connection_lost(), which resolves ``on_exit`` so a
        waiter is not left pending after the error.
        """
        if self.transport:
            self.transport.close()

    def connection_lost(self, exc):
        """Resolve ``on_exit`` with the reply, unless it is already done."""
        # The waiter may have been cancelled or resolved already; calling
        # set_result() then would raise InvalidStateError.
        if not self.on_exit.done():
            self.on_exit.set_result(self.result)
# DB = RedisPool(env.get("REDIS", "127.0.0.1"), db=int(env.get("REDIS_DB", 0)))
# In-memory name -> address table, filled from the file named by the first
# command-line argument.  Each line is "<address> <name> [<name> ...]"; text
# after '#' is ignored and lines with fewer than two fields are skipped.
DB = {}
hosts_file = sys.argv[1]
with open(hosts_file, "r") as fd:
    for line in fd:
        record = line.partition("#")[0].split()
        if len(record) < 2:
            continue
        ip = record[0]
        print(record)
        for name in record[1:]:
            print(name)
            DB[name] = ip
print(DB)

# Upstream resolver as (host, port): DNS_RELAY from the environment, else the
# value returned by get_default_resolver().
DNS_RELAY = (env.get("DNS_RELAY", get_default_resolver()), 53)
# Local address to bind to, overridable through the BIND environment variable.
HOST = env.get("BIND", "0.0.0.0")
class DNSServer:
    """UDP DNS server that answers from ``DB`` and forwards everything else.

    Incoming datagrams are read from a non-blocking socket and each one is
    handled in its own task.  Outgoing datagrams are placed on a queue that a
    single sender task drains.
    """

    # Seconds to wait for the upstream resolver before giving up on a
    # forwarded request.
    forward_timeout = 5.0

    def __init__(self, loop=None):
        """Create the server.

        Args:
            loop: event loop to run on.  When omitted the current event loop
                is looked up at construction time (not at class definition
                time, which would pin every instance to whichever loop
                existed on import).
        """
        self.loop = loop if loop is not None else asyncio.get_event_loop()
        self.sock = None
        self.event = asyncio.Event()
        self.queue = deque()

    async def _forward(self, data: bytes) -> Optional[bytes]:
        """Relay *data* to ``DNS_RELAY`` and return its reply, or None."""
        on_exit = self.loop.create_future()
        transport, _ = await self.loop.create_datagram_endpoint(
            lambda: DNSForward(data, on_exit), remote_addr=DNS_RELAY
        )
        try:
            # shield() keeps on_exit itself from being cancelled on timeout,
            # so the protocol can still resolve it when the transport closes.
            return await asyncio.wait_for(
                asyncio.shield(on_exit), self.forward_timeout
            )
        except asyncio.TimeoutError:
            print("upstream timeout", DNS_RELAY)
            return None
        finally:
            transport.close()

    async def on_data_received(self, data: bytes, addr: Tuple[str, int]):
        """Handle one request datagram received from *addr*.

        A name found in ``DB`` is answered locally.  A name that is not found
        and contains a dot is forwarded upstream, and the upstream reply (if
        any) is relayed to the client.  A name that is not found and has no
        dot gets an NXDOMAIN response.
        """
        trans_id, queries = parse_query(data)
        for q in queries:
            domain = get_domain(q)
            print(trans_id, domain)
            res = DB.get(domain)
            if not res and "." in domain:
                reply = await self._forward(data)
                # No reply (timeout or upstream error): send nothing rather
                # than queueing None for the sender.
                if reply:
                    self.send(reply, addr)
                return
            ip = ipaddress.IPv4Address(res).packed if res else None
            self.send(build_answer(trans_id, queries, answer=ip), addr)

    def run(self, host: str = "0.0.0.0", port: int = 53):
        """Bind the UDP socket and schedule the receive and send tasks."""
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.sock.setblocking(False)
        self.sock.bind((host, port))
        asyncio.ensure_future(self.recv_periodically(), loop=self.loop)
        asyncio.ensure_future(self.send_periodically(), loop=self.loop)

    def sock_recv(
        self, fut: Optional[asyncio.Future] = None, registered: bool = False
    ) -> Optional[asyncio.Future]:
        """Return a future resolved with ``(data, addr)`` of the next datagram.

        If the socket has nothing to read, this method registers itself as a
        reader callback and retries when the socket becomes readable.  Any
        other receive error is printed and the future is resolved with ``0``.
        """
        fd = self.sock.fileno()
        if fut is None:
            fut = self.loop.create_future()
        if registered:
            # Called back by the loop: drop the reader before retrying.
            self.loop.remove_reader(fd)
        try:
            data, addr = self.sock.recvfrom(2048)
        except (BlockingIOError, InterruptedError):
            self.loop.add_reader(fd, self.sock_recv, fut, True)
        except Exception as ex:
            print(ex)
            fut.set_result(0)
        else:
            fut.set_result((data, addr))
        return fut

    async def recv_periodically(self):
        """Receive datagrams forever, spawning a handler task for each."""
        while True:
            received = await self.sock_recv()
            if not received:
                # sock_recv() reports a failed read as 0; skip it instead of
                # dying on the tuple unpack, and yield so a persistent error
                # cannot starve other tasks.
                await asyncio.sleep(0)
                continue
            data, addr = received
            asyncio.ensure_future(self.on_data_received(data, addr), loop=self.loop)

    def send(self, data: bytes, addr: Tuple[str, int]):
        """Queue *data* for delivery to *addr* and wake the sender task."""
        self.queue.append((data, addr))
        self.event.set()

    def sock_send(
        self,
        data: bytes,
        addr: Tuple[str, int],
        fut: Optional[asyncio.Future] = None,
        registered: bool = False,
    ) -> Optional[asyncio.Future]:
        """Return a future resolved with the number of bytes sent to *addr*.

        If the socket is not writable, this method registers itself as a
        writer callback and retries later.  Any other send error is printed
        and the future is resolved with ``0``.
        """
        fd = self.sock.fileno()
        if fut is None:
            fut = self.loop.create_future()
        if registered:
            # Called back by the loop: drop the writer before retrying.
            self.loop.remove_writer(fd)
        try:
            sent = self.sock.sendto(data, addr)
        except (BlockingIOError, InterruptedError):
            self.loop.add_writer(fd, self.sock_send, data, addr, fut, True)
        except Exception as ex:
            print(ex)
            fut.set_result(0)
        else:
            fut.set_result(sent)
        return fut

    async def send_periodically(self):
        """Drain the outgoing queue each time ``send()`` signals new data."""
        while True:
            await self.event.wait()
            try:
                while self.queue:
                    data, addr = self.queue.popleft()
                    _ = await self.sock_send(data, addr)
            finally:
                self.event.clear()
async def main(loop):
    """Create a DNSServer on *loop* and start it on ``HOST``, port 53."""
    server = DNSServer(loop)
    server.run(host=HOST, port=53)
if __name__ == "__main__":
    # Create and install the loop explicitly: calling get_event_loop() with no
    # running loop is deprecated in current Python versions.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        # main() only schedules the server tasks; run_forever() keeps them
        # running afterwards.
        loop.run_until_complete(main(loop))
        loop.run_forever()
    finally:
        loop.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment