Skip to content

Instantly share code, notes, and snippets.

@itdaniher
Forked from infra-0-0/dns.py
Created January 23, 2018 19:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save itdaniher/20ee0511aeb8349f06b78527dc0ed84e to your computer and use it in GitHub Desktop.
Save itdaniher/20ee0511aeb8349f06b78527dc0ed84e to your computer and use it in GitHub Desktop.
Python 3.5 asyncio DNS resolver, adapted from code of unknown origin.
import os
import socket
import struct
import asyncio
import logging
# Address and UDP port of the upstream DNS server that build_resolver() connects to.
DNS_SERVER, DNS_PORT = '8.8.8.8', 53
def encode_header(packet_id):
    """Build the 12-byte header of a single-question DNS query.

    :param packet_id: 16-bit message ID echoed back by the server.
    :returns: header bytes with flags word 0x0100 (only the recursion-desired
        bit set), QDCOUNT 1, and ANCOUNT/NSCOUNT/ARCOUNT all 0.
    """
    recursion_desired = 0x0100
    question_count = 1
    return struct.pack('!6H', packet_id, recursion_desired, question_count, 0, 0, 0)
def decode_header(header):
    """Unpack the 12-byte header of a DNS response.

    Authority and additional record counts are read but not validated: callers
    in this module only walk the question and answer sections, so records that
    follow them are simply left unparsed (previously any such record made the
    whole response fail, e.g. the SOA that accompanies an NXDOMAIN).

    :param header: exactly the first 12 bytes of the response.
    :returns: ``(packet_id, qdcount, ancount)``.
    :raises Exception: if the response code (low 4 bits of the second flags
        byte) is non-zero, e.g. NXDOMAIN or SERVFAIL.
    :raises struct.error: if *header* is not exactly 12 bytes long.
    """
    packet_id, _flags1, flags2, qdcount, ancount, _nscount, _arcount = struct.unpack('!HBBHHHH', header)
    rcode = flags2 & 0x0F
    if rcode:
        # RFC 1035 section 4.1.1 response codes; anything else is reported numerically.
        rcode_names = {1: 'FORMERR', 2: 'SERVFAIL', 3: 'NXDOMAIN', 4: 'NOTIMP', 5: 'REFUSED'}
        raise Exception('DNS error: %s (rcode %d)' % (rcode_names.get(rcode, 'unknown'), rcode))
    return (packet_id, qdcount, ancount)
def decode_query(data, offset):
    """Read one question-section entry from *data* starting at *offset*.

    Every non-zero length byte is treated as a plain label length (compression
    pointers are not recognised); a zero byte ends the name. Labels are decoded
    as UTF-8 and joined with dots.

    :returns: ``(host, query_type, next_offset)`` where *next_offset* points just
        past the 4-byte QTYPE/QCLASS pair.
    :raises Exception: if QCLASS is not 1.
    """
    raw_labels = []
    cursor = offset
    length = data[cursor]
    cursor += 1
    while length:
        raw_labels.append(data[cursor:cursor + length])
        cursor += length
        length = data[cursor]
        cursor += 1
    host = '.'.join(raw.decode('utf-8') for raw in raw_labels)
    query_type, query_class = struct.unpack('!HH', data[cursor:cursor + 4])
    if query_class != 1:
        raise Exception('wrong query class')
    return host, query_type, cursor + 4
def decode_cname(data):
    """Decode the RDATA of a CNAME record.

    The name is read label by label until a zero-length terminator or a 2-byte
    compression pointer ends it. Pointers are not followed because only the
    RDATA slice, not the whole message, is available here. Previously only
    "one label + pointer" was accepted, so common CNAMEs (a bare pointer, or
    several labels before the pointer) made the whole lookup fail.

    :param data: the record's RDATA bytes.
    :returns: ``(subdomain, name)``.

        - ``subdomain`` is the dot-joined labels that appear literally in the
          RDATA, or ``''`` when the RDATA is only a pointer.
        - ``name`` is the ``struct.unpack('!H', ...)`` tuple of the trailing
          compression pointer (0xC0 flag bits included), or ``()`` when the
          name ends with a zero terminator instead.
    :raises Exception: if the RDATA is truncated, uses a reserved label type,
        or has bytes left over after the name ends.
    """
    labels = []
    name = None  # stays None until a terminator or pointer ends the name
    offset = 0
    while offset < len(data):
        length = data[offset]
        if length == 0:
            name = ()
            offset += 1
            break
        if length & 0xC0 == 0xC0:
            # Compression pointer: two bytes, always the last element of a name.
            if offset + 2 > len(data):
                raise Exception('wrong length for name')
            name = struct.unpack('!H', data[offset:offset + 2])
            offset += 2
            break
        if length & 0xC0:
            # 0x40 / 0x80 label types are reserved.
            raise Exception('unsupported label type')
        label = data[offset + 1:offset + 1 + length]
        if len(label) != length:
            raise Exception('wrong length for name')
        labels.append(label.decode('utf-8'))
        offset += 1 + length
    if name is None or offset != len(data):
        raise Exception('wrong length for name')
    return ('.'.join(labels), name)
def decode_reply(data, offset):
    """Decode one answer-section resource record starting at *offset*.

    The first two bytes of the record are taken verbatim as a 16-bit ``name``.
    This is correct only when the owner name is a compression pointer;
    NOTE(review): an uncompressed owner name would misalign the rest of the parse.

    :returns: ``((name, reply_type, ttl, value), next_offset)``.

        - For type 1, ``value`` is the dotted-quad IPv4 string.
        - For type 5, ``value`` is the result of ``decode_cname`` on the RDATA.
        - Otherwise ``value`` is the raw RDATA bytes.
    :raises Exception: if the record class is not 1.
    """
    rdata_start = offset + 12
    name, reply_type, reply_class, ttl, data_len = struct.unpack("!HHHIH", data[offset:rdata_start])
    if reply_class != 1:
        raise Exception('wrong reply class')
    rdata_end = rdata_start + data_len
    rdata = data[rdata_start:rdata_end]
    if reply_type == 1:
        value = socket.inet_ntop(socket.AF_INET, rdata)
    elif reply_type == 5:
        value = decode_cname(rdata)
    else:
        value = rdata
    return (name, reply_type, ttl, value), rdata_end
def encode_query(host):
    """Encode the question section of an A/IN lookup for *host*.

    :param host: dotted host name. A single trailing dot (fully-qualified form)
        is accepted, and ``'.'`` encodes the root name. Labels are encoded as
        UTF-8, matching how ``decode_query`` decodes them.
    :returns: length-prefixed labels, the zero root terminator, then QTYPE=1 (A)
        and QCLASS=1 (IN).
    :raises ValueError: if a label is empty or longer than 63 bytes once encoded,
        or the encoded name exceeds 255 bytes.
    """
    name = host[:-1] if host.endswith('.') else host
    labels = [part.encode('utf-8') for part in name.split('.')] if name else []
    for label in labels:
        # The length prefix must count encoded bytes, not characters, and label
        # lengths above 63 would be read back as compression pointers.
        if not 0 < len(label) <= 63:
            raise ValueError('invalid DNS label length in %r' % (host,))
    encoded_name = b''.join(bytes((len(label),)) + label for label in labels) + b'\x00'
    if len(encoded_name) > 255:
        raise ValueError('DNS name too long: %r' % (host,))
    return encoded_name + struct.pack('!HH', 1, 1)
class UdpDnsClientProtocol(asyncio.DatagramProtocol):
    """asyncio datagram protocol that issues DNS A queries over one UDP endpoint.

    Attributes:
        fut: future for the query in flight.

            - ``datagram_received`` resolves it with the address from the last
              A record in the matching reply, or with the string ``'done'``
              when the reply holds no A record.
            - ``datagram_received`` fails it with any exception raised while
              decoding that reply.
            - ``error_received`` fails it with the transport error.
        transport: future that ``connection_made`` resolves with the transport.
        packet_id: 16-bit DNS message ID of the query in flight. Datagrams with
            any other ID are ignored.

    Only one query is tracked at a time. Each ``resolve`` call installs a new
    ``fut`` and ``packet_id``. A reply to an earlier query is therefore ignored,
    and that earlier call times out instead of receiving another host's answer.
    """

    def __init__(self):
        self.fut = asyncio.Future()
        self.transport = asyncio.Future()
        self.packet_id = int.from_bytes(os.urandom(2), 'little')

    def connection_made(self, transport):
        """Publish *transport* to coroutines waiting in ``get_transport``."""
        self.transport.set_result(transport)

    async def get_transport(self):
        """Return the datagram transport, waiting until the endpoint is connected."""
        return await self.transport

    async def resolve(self, host, timeout=1.0):
        """Query the connected server for the A record of *host*.

        :param host: host name to look up (encoded by ``encode_query``).
        :param timeout: seconds to wait for the reply. Defaults to 1.0.
        :returns: IPv4 address string from the last A record in the reply, or
            ``'done'`` if the reply contained no A record.
        :raises asyncio.TimeoutError: if no matching reply arrives in time.
        :raises Exception: whatever decoding the matching reply raised.
        """
        transport = await self.get_transport()
        # No await between here and sendto(), so a concurrent resolve() cannot
        # swap fut/packet_id after this query's ID is chosen but before it is sent.
        self.fut = asyncio.Future()
        # A fresh ID per query keeps a late reply to a previous (timed-out or
        # superseded) query from being accepted as the answer to this one.
        self.packet_id = int.from_bytes(os.urandom(2), 'little')
        transport.sendto(encode_header(self.packet_id) + encode_query(host))
        return await asyncio.wait_for(self.fut, timeout)

    def datagram_received(self, data, addr):
        """Decode a reply and settle ``fut``; ignore datagrams not meant for it."""
        HEADER_LEN = 12
        if len(data) < HEADER_LEN:
            logging.debug("ignoring %d-byte datagram: shorter than a DNS header", len(data))
            return
        # Match the message ID before full decoding so stray or stale datagrams
        # cannot fail the query that is actually in flight.
        (packet_id,) = struct.unpack_from('!H', data)
        if packet_id != self.packet_id:
            logging.debug("ignoring datagram with unexpected packet id %d", packet_id)
            return
        if self.fut.done():
            # Duplicate reply, or the waiter already timed out / was cancelled;
            # setting a result or exception now would raise InvalidStateError.
            return
        result = 'done'
        try:
            packet_id, qdcount, ancount = decode_header(data[:HEADER_LEN])
            offset = HEADER_LEN
            for i in range(qdcount):
                host, query_type, offset = decode_query(data, offset)
                logging.debug("QUERY #%s: %r, %s", i, host, query_type)
            for i in range(ancount):
                (name, reply_type, ttl, reply), offset = decode_reply(data, offset)
                if reply_type == 1:
                    result = reply
                logging.debug("REPLY #%s: %r", i, reply)
        except Exception as exc:
            # Forward the decode failure to the coroutine awaiting this query.
            self.fut.set_exception(exc)
            return
        self.fut.set_result(result)

    def error_received(self, exc):
        """Fail the pending query with a transport-level error, if still pending."""
        logging.info("ERROR RECEIVED %s", exc)
        if not self.fut.done():
            self.fut.set_exception(exc)
async def build_resolver(loop = None):
    """Connect a UDP endpoint to ``DNS_SERVER``:``DNS_PORT`` and return its resolver.

    :param loop: event loop to create the endpoint on. Falls back to
        ``asyncio.get_event_loop()``.
    :returns: the bound ``resolve`` coroutine method of a new
        ``UdpDnsClientProtocol``. The endpoint is left open for reuse across
        calls; nothing here closes it.
    """
    if not loop:
        loop = asyncio.get_event_loop()
    protocol = UdpDnsClientProtocol()
    server = (DNS_SERVER, DNS_PORT)
    await loop.create_datagram_endpoint(lambda: protocol, remote_addr=server)
    return protocol.resolve
async def lookup(target, resolver = None, loop = None):
    """Resolve *target* with *resolver*, building a one-off resolver if none is given.

    :param target: host name to look up.
    :param resolver: coroutine function taking a host name, such as the value
        returned by ``build_resolver``. Pass one in to share a UDP endpoint
        across lookups. When omitted, a private resolver is built for this call
        and its transport is closed before returning, so one-off lookups no
        longer leak a socket.
    :param loop: event loop handed to ``build_resolver`` when *resolver* is omitted.
    :returns: whatever the resolver returns for *target*.
    """
    if resolver is not None:
        return await resolver(target)
    resolver = await build_resolver(loop)
    try:
        return await resolver(target)
    finally:
        # build_resolver returns the protocol's bound ``resolve`` method; close
        # the transport behind it because this call owns the endpoint.
        transport_future = resolver.__self__.transport
        if transport_future.done():
            transport_future.result().close()
if __name__ == "__main__":
    import sys
    # Positional arguments are host names. Filtering out the flag means
    # "--debug HOST" works as well as "HOST --debug".
    hosts = [arg for arg in sys.argv[1:] if arg != '--debug']
    if not hosts:
        sys.exit('usage: %s [--debug] HOST' % sys.argv[0])
    # Create and install a loop explicitly: calling get_event_loop() with no
    # current loop is deprecated. Installing it keeps the get_event_loop()
    # fallback in build_resolver pointing at this loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    if '--debug' in sys.argv:
        loop.set_debug(True)
        logging.basicConfig(level=logging.DEBUG)
    try:
        print(loop.run_until_complete(lookup(hosts[0])))
    finally:
        loop.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment