Skip to content

Instantly share code, notes, and snippets.

@infra-0-0
Last active September 30, 2020 14:47
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 4 You must be signed in to fork a gist
  • Save infra-0-0/7bf157202c17681d62667a006eb15abd to your computer and use it in GitHub Desktop.
Save infra-0-0/7bf157202c17681d62667a006eb15abd to your computer and use it in GitHub Desktop.
python3.5 asyncio dns resolver, based on code from who knows where
"""
The most simple DNS client written for Python with asyncio:
* Only A record is support (no CNAME, no AAAA, no MX, etc.)
* Almost no error handling
* Doesn't support fragmented UDP packets (is it possible?)
"""
import asyncio
import logging
import random
import socket
import struct
import sys
DNS_SERVER = '8.8.8.8'
DNS_PORT = 53
HEADER_LEN = 12
def encode_header(packet_id):
qr = 0 # is a request (or answer)? (no: request), 1 bit
opcode = 0 # standard query, 4 bit
aa = 0 # authorative answer? (no), 1 bit
tc = 0 # truncated? (no), 1 bit
rc = 1 # recursive? (yes), 1 bit
#flags1 = qr | opcode << 1 | aa << 5 | tc << 6 | rc << 7
flags1 = rc | tc << 1 | aa << 2 | opcode | qr << 7
ra = 0 # recursivite allowed? (no), 1 bit
z = 0 # reserved to future use (0), 3 bits
rcode = 0 # reply type (0: no error), 4 bits
#flags2 = ra | z << 1 | rcode << 4
flags2 = rcode | z << 6 | ra << 7
qdcount = 1 # number of questions
ancount = 0 # number of replies
nscount = 0 # number of authorities
arcount = 0 # number of additionals
header = struct.pack('!HBBHHHH', packet_id, flags1, flags2, qdcount, ancount, nscount, arcount)
assert len(header) == HEADER_LEN
return header
def decode_header(header):
parts = struct.unpack('!HBBHHHH', header)
packet_id, flags1, flags2, qdcount, ancount, nscount, arcount = parts
#flags1 = rc | tc << 1 | aa << 2 | opcode | qr << 7
#assert qr == 1 # is a request (or answer)? (no: request), 1 bit
#opcode = 0 # standard query, 4 bit
#aa = 0 # authorative answer? (no), 1 bit
#tc = 0 # truncated? (no), 1 bit
#rc = 1 # recursive? (yes), 1 bit
# flags2 = rcode | z << 6 | ra << 7
# ra = 0 # recursivite allowed? (no), 1 bit
# z = 0 # reserved to future use (0), 3 bits
# rcode = 0 # reply type (0: no error), 4 bits
assert nscount == 0 # number of authorities
assert arcount == 0 # number of additionals
return (packet_id, qdcount, ancount)
def decode_query(data, offset):
labels = []
while True:
part_len = data[offset]
offset += 1
if not part_len:
break
label = data[offset:offset+part_len]
offset += part_len
labels.append(label)
# FIXME: support unicode
labels = [label.decode('utf-8') for label in labels]
host = '.'.join(labels)
query_type, query_class = struct.unpack('!HH', data[offset:offset+4])
offset += 4
assert query_class == 1
return host, query_type, offset
def decode_cname(data):
data_len = data[0]
subdomain = data[1:1+data_len]
# FIXME: support unicode
subdomain = subdomain.decode('utf-8')
name = data[1+data_len:]
assert len(name) == 2, repr(name)
name = struct.unpack('!H', name)
return (subdomain, name)
def decode_reply(data, offset):
name, reply_type, reply_class, ttl, data_len = struct.unpack("!HHHIH", data[offset:offset+12])
offset += 12
assert reply_class == 1
reply = data[offset:offset+data_len]
offset += data_len
if reply_type == 1:
reply = socket.inet_ntop(socket.AF_INET, reply)
elif reply_type == 5:
reply = decode_cname(reply)
reply = (name, reply_type, ttl, reply)
return reply, offset
def encode_query(host):
# FIXME: support unicode
parts = [part.encode('ascii') for part in host.split('.')]
parts.append(b'')
query = []
for part in parts:
query.append(struct.pack('!B', len(part)))
query.append(part)
query_type = 1 # A record
query_class = 1 # IN
query.append(struct.pack('!HH', query_type, query_class))
return b''.join(query)
class UdpDnsClientProtocol(asyncio.Protocol):
def __init__(self):
self.fut = asyncio.Future()
self.transport = asyncio.Future()
self.packet_id = random.randint(0, 2**16-1)
def connection_made(self, transport):
self.transport.set_result(transport)
async def get_transport(self):
return await self.transport
async def resolve(self, host):
# try nice well defined async resolve
try:
self.fut = asyncio.Future()
header = encode_header(self.packet_id)
query = encode_query(host)
packet = header + query
t = await self.get_transport()
t.sendto(packet)
ip = await asyncio.wait_for(self.fut, 0.5)
return ip
# fall back to getaddrinfo, and finally 0.0.0.0
except:
try:
res = await loop.getaddrinfo(host, 0)
return res[-1][-1][0]
except:
return '0.0.0.0'
def datagram_received(self, data, addr):
result = 'done'
try:
(packet_id, qdcount, ancount) = decode_header(data[:HEADER_LEN])
if packet_id != self.packet_id:
self.error("invalid packet id")
return
offset = HEADER_LEN
for i in range(qdcount):
host, query_type, offset = decode_query(data, offset)
print("QUERY #%s: %r, %s" % (i, host, query_type))
for i in range(ancount):
reply, offset = decode_reply(data, offset)
name, reply_type, ttl, reply = reply
if reply_type == 1:
result = reply
print("REPLY #%s: %r" % (i, reply))
except Exception as exc:
if not self.fut.cancelled():
self.fut.set_exception(exc)
return
if not self.fut.cancelled() and not self.fut.done():
self.fut.set_result(result)
def error_received(self, exc):
print("ERROR RECEIVED", exc)
self.fut.set_exception(exc)
# TCP only
def connection_lost(self, exc):
if exc:
print("CONNECTION LOST", exc)
if not self.fut.done():
self.fut.set_exception(exc)
async def build_resolver():
proto = UdpDnsClientProtocol()
conn = await loop.create_datagram_endpoint(lambda: proto, remote_addr=(DNS_SERVER, DNS_PORT))
return proto.resolve
if __name__ == "__main__":
import time
loop = asyncio.get_event_loop()
async def test():
top1m = [x.strip().split(',')[-1] for x in open('top-1m.csv').read().split('\n') if ',' in x]
resolver = await build_resolver()
for h in top1m:
start = time.time()
res = await resolver(h)
print(h, res)
end = time.time()
print('DUR', end-start)
sys.stdout.flush()
#loop.set_debug(True)
#logging.basicConfig(level=logging.DEBUG)
loop.run_until_complete(test())
loop.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment