Skip to content

Instantly share code, notes, and snippets.

@TakesxiSximada
Last active December 25, 2015 13:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save TakesxiSximada/802d36068f09a3393541 to your computer and use it in GitHub Desktop.
Save TakesxiSximada/802d36068f09a3393541 to your computer and use it in GitHub Desktop.
DNSサーバのさわり書いてみた
#! /usr/bin/env python
# -*- coding: utf-8 -*-
"""Dummy Name Server
See
- http://www.ietf.org/rfc/rfc1034.txt
- http://www.ietf.org/rfc/rfc1035.txt
- http://www.ietf.org/rfc/rfc1886.txt
"""
import re
import sys
import argparse
from functools import partial
from socketserver import (
ThreadingUDPServer,
BaseRequestHandler,
)
import enum
from binarize import (
Structure,
BYTES,
UINT16,
UINT32,
)
# TTL value written into every resource record produced by
# UDPNameRequestHandler.build_response (presumably seconds -- TODO confirm).
TTL = 9
# Static lookup table consulted by build_response. The key is the tuple of
# labels that get_qname() yields for the question name, which includes the
# trailing empty label; the value is the raw record data sent back.
NAME_IPADDR = {
(b'test', b''): b'\x7f\x00\x00\x01', # 127.0.0.1
}
def display_buffer(name, buf):
    """Print *buf* to stdout as a labelled hex dump.

    The output is three lines: a banner containing *name*, every item of
    *buf* as a two-digit uppercase hex number followed by a space, and a
    closing rule of twenty ``#`` characters.
    """
    banner = '#' * 10
    print(banner, name, banner)
    print(''.join('{:02X} '.format(octet) for octet in buf))
    print('#' * 20)
@enum.unique
class QueryResponse(enum.IntEnum):
    """Values of the one-bit QR header flag."""

    query = 0     # message is a query
    response = 1  # message is a response
@enum.unique
class OPCode(enum.IntEnum):
    """Values of the OPCODE header field (kind of query)."""

    standard_query = 0         # QUERY
    inverse_query = 1          # IQUERY
    server_status_request = 2  # STATUS
@enum.unique
class AuthoritativeAnswer(enum.IntEnum):
    """Values of the one-bit AA header flag."""

    no = 0   # answer is not authoritative
    yes = 1  # answer is authoritative
@enum.unique
class TrunCation(enum.IntEnum):
    """Values of the one-bit TC (truncation) header flag."""

    no = 0   # message is complete
    yes = 1  # message was truncated
@enum.unique
class RecusionDesired(enum.IntEnum):
    """Values for a one-bit flag; presumably the RD (recursion desired) bit.

    NOTE(review): the class name is misspelled ("Recusion") and the member
    names do not obviously match the usual "not desired / desired" meaning;
    both are kept because they are part of the public interface.
    """

    desired = 0
    recursive = 1
@enum.unique
class RecursionAvailable(enum.IntEnum):
    """Values of the one-bit RA (recursion available) header flag."""

    unable = 0     # server does not offer recursion
    available = 1  # server offers recursion
@enum.unique
class ResponseCode(enum.IntEnum):
    """Values of the four-bit RCODE header field."""

    success = 0          # no error
    format_error = 1     # query could not be interpreted
    server_error = 2     # server failure
    name_error = 3       # queried name does not exist
    not_implemented = 4  # kind of query not supported
    deny = 5             # request refused
class QueryType(enum.IntEnum):
    """Numeric record types that may appear in a question's QTYPE field."""

    A = 1      # address
    NS = 2     # name server
    CNAME = 5  # canonical name
    PTR = 12   # pointer
    MX = 15    # mail exchange
    AAAA = 28  # IPv6
    ANY = 255  # any
@enum.unique
class QueryClass(enum.IntEnum):
    """Numeric values of the CLASS / QCLASS field as assigned by RFC 1035.

    The previous definition mapped ``CH`` to 2, which RFC 1035 assigns to
    CS; CHAOS is 3. The remaining RFC 1035 classes are listed as well so
    that ``QueryClass(value)`` resolves for every assigned value.
    """

    IN = 1     # internet
    CS = 2     # CSNET (obsolete)
    CH = 3     # CHAOS
    HS = 4     # Hesiod
    ANY = 255  # any class (valid in questions only)
@enum.unique
class FlagMask(enum.IntEnum):
    """Bit masks selecting each sub-field of the 16-bit header ``flags`` word.

    The masks are disjoint and together cover all sixteen bits. ``tc`` used
    to carry the same value as ``aa``, which made it an enum alias of ``aa``
    so that writing the TC flag overwrote AA; it now selects bit 9.
    ``enum.unique`` guards against reintroducing such a duplicate.
    """

    qr      = 0b1000000000000000  # bit 15
    opcode  = 0b0111100000000000  # bits 14-11
    aa      = 0b0000010000000000  # bit 10
    tc      = 0b0000001000000000  # bit 9
    rd      = 0b0000000100000000  # bit 8
    ra      = 0b0000000010000000  # bit 7
    reserve = 0b0000000001110000  # bits 6-4
    rcode   = 0b0000000000001111  # bits 3-0
def get_shift_count(value):
    """Return how many ``0`` digits end the binary representation of *value*.

    For a bit mask this is the distance of its lowest set bit from bit 0.
    Note that ``bin(0)`` is ``'0b0'``, so an input of 0 yields 1.
    """
    digits = bin(value)
    return len(digits) - len(digits.rstrip('0'))
class FlagShiftCount(enum.IntEnum):
    """Right-shift distance for each header flag field.

    Every member is derived from the FlagMask member of the same name by
    counting that mask's trailing zero bits with get_shift_count().
    """

    # The IntEnum members are passed directly: bin() accepts them as ints.
    qr = get_shift_count(FlagMask.qr)
    opcode = get_shift_count(FlagMask.opcode)
    aa = get_shift_count(FlagMask.aa)
    tc = get_shift_count(FlagMask.tc)
    rd = get_shift_count(FlagMask.rd)
    ra = get_shift_count(FlagMask.ra)
    reserve = get_shift_count(FlagMask.reserve)
    rcode = get_shift_count(FlagMask.rcode)
class Header(Structure):
    """Message header made of six UINT16 fields.

    The sub-fields packed into ``flags`` are exposed as read/write
    properties (``qr``, ``opcode``, ``aa``, ``tc``, ``ra``, ``rd``,
    ``reserve``, ``rcode``); each uses the FlagMask / FlagShiftCount member
    of the same name.
    """

    identifier = UINT16()
    flags = UINT16()
    query_count = UINT16()
    answer_count = UINT16()
    authority_rr_count = UINT16()
    addon_rr_count = UINT16()

    def _read_flags(self, mask, shift_count):
        """Return the bits of ``flags`` selected by *mask*, shifted down."""
        return (self.flags & mask) >> shift_count

    def _write_flags(self, value, mask, shift_count):
        """Replace the bits of ``flags`` selected by *mask* with *value*."""
        untouched = self.flags & ~mask
        self.flags = untouched | ((value << shift_count) & mask)

    def _flag_property(flag, _getter=_read_flags, _setter=_write_flags):
        """Build the property for *flag*; only used while the class body runs."""
        bounds = {
            'mask': FlagMask[flag].value,
            'shift_count': FlagShiftCount[flag].value,
        }
        return property(partial(_getter, **bounds), partial(_setter, **bounds))

    qr = _flag_property('qr')
    opcode = _flag_property('opcode')
    aa = _flag_property('aa')
    tc = _flag_property('tc')
    ra = _flag_property('ra')
    rd = _flag_property('rd')
    reserve = _flag_property('reserve')
    rcode = _flag_property('rcode')

    # Keep the class namespace the same as with hand-written properties.
    del _flag_property
class QuestionRecord(Structure):
    """One question entry: a name followed by UINT16 type and class."""

    qname = BYTES(size=32, fill=b'')
    qtype = UINT16()
    qclass = UINT16()

    @classmethod
    def decode(cls, data):
        """Build a record from *data*, which must start with the question.

        The name is taken up to and including the first zero byte; the
        bytes after it are decoded as the type and class fields.
        Raises ValueError (from ``index``) if *data* has no zero byte.
        """
        class _FixedPart(Structure):
            qtype = UINT16()
            qclass = UINT16()

        question = cls()
        name_length = data.index(b'\x00') + 1  # include the terminator
        question.qname = data[:name_length]
        fixed_part = _FixedPart.decode(data[name_length:])
        question.qtype = fixed_part.qtype
        question.qclass = fixed_part.qclass
        return question

    def get_qname(self):
        """Return a generator over the labels of ``qname``."""
        return parse_qname(self.qname)

    def __len__(self):
        """Return the length of the encoded record."""
        return len(self.encode())
class ResourceRecord(Structure):
    """One resource record: name, type, class, TTL and length-prefixed data."""

    name_ = BYTES(size=32, fill=b'')
    type_ = UINT16()
    class_ = UINT16()
    ttl = UINT32()
    resource_data_length = UINT16()
    resource_data = BYTES(size=32, fill=b'')

    @classmethod
    def decode(cls, data):
        """Build a record from *data*, which must start with the record.

        Fixes over the previous version:
        - the method is now a classmethod; it takes ``cls`` and is called
          on the class itself by ProtocolDataFactory;
        - the name is stored in the declared field ``name_`` instead of an
          unrelated attribute ``name``;
        - the name keeps its terminating zero byte, matching
          QuestionRecord.decode and what build_response stores in ``name_``.

        Raises ValueError (from ``index``) if *data* has no zero byte.
        """
        class _AnalyzeRecord(Structure):
            type_ = UINT16()
            class_ = UINT16()
            ttl = UINT32()
            resource_data_length = UINT16()

        record = cls()
        name_end = data.index(b'\x00') + 1  # include the terminator
        record.name_ = data[:name_end]
        data = data[name_end:]

        # Fixed-size part that follows the name.
        analyze_record = _AnalyzeRecord.decode(data[:_AnalyzeRecord.size])
        record.type_ = analyze_record.type_
        record.class_ = analyze_record.class_
        record.ttl = analyze_record.ttl
        record.resource_data_length = analyze_record.resource_data_length

        # Variable-size payload, as long as the length field says.
        data = data[_AnalyzeRecord.size:]
        record.resource_data = data[:record.resource_data_length]
        return record

    def __len__(self):
        """Return the length of the encoded record."""
        return len(self.encode())
class ProtocolDataFactory:
    """Callable that creates instances of one structure class.

    The wrapped class must provide a ``size`` attribute and a
    ``decode(data)`` callable.
    """

    def __init__(self, structure):
        self.structure = structure

    @property
    def size(self):
        """The ``size`` attribute of the wrapped structure class."""
        return self.structure.size

    def __call__(self, data=None):
        """Decode *data*; with no argument, decode ``size`` zero bytes."""
        raw = b'\x00' * self.size if data is None else data
        return self.structure.decode(raw)
# Module-level factories, one per structure class. Calling a factory with
# bytes decodes them through the wrapped class; calling it with no argument
# decodes a zero-filled buffer of the structure's size.
header_factory = ProtocolDataFactory(Header)
question_record_factory = ProtocolDataFactory(QuestionRecord)
resource_record_factory = ProtocolDataFactory(ResourceRecord)
def parse_qname(qname):
    """Yield the labels of a length-prefixed name, one at a time.

    Each label is preceded by one byte holding its length. A zero length
    byte yields an empty label, so a name ending in ``\\x00`` produces a
    trailing ``b''``.
    """
    remaining = qname
    while remaining:
        label_end = 1 + remaining[0]  # first byte is the label length
        yield remaining[1:label_end]
        remaining = remaining[label_end:]
def get_qname(*args, **kwds):
    """Return the labels produced by parse_qname() as a list.

    All arguments are passed through to parse_qname() unchanged.
    """
    return list(parse_qname(*args, **kwds))
def parse_request(data):
    """Split a raw request into header, question records and leftover bytes.

    The header is decoded from the first ``header_factory.size`` bytes and
    then ``header.query_count`` question records are decoded one after the
    other. Returns ``(header, question_records, remaining_bytes)``.
    """
    header_size = header_factory.size
    header = header_factory(data[:header_size])
    remainder = data[header_size:]

    question_records = []
    for _ in range(header.query_count):
        question = question_record_factory(remainder)
        question_records.append(question)
        # Advance past the record that was just decoded.
        remainder = remainder[len(question):]
    return header, question_records, remainder
class UDPNameRequestHandler(BaseRequestHandler):
    """Answer one UDP request: parse it, log it, send back the response."""

    # Header attributes printed by parse_request(), in output order.
    _HEADER_LOG_FIELDS = (
        ('ID:', 'identifier'),
        ('QR:', 'qr'),
        ('OPCODE:', 'opcode'),
        ('AA:', 'aa'),
        ('TC:', 'tc'),
        ('RD:', 'rd'),
        ('RA:', 'ra'),
        ('Reserve:', 'reserve'),
        ('RCODE:', 'rcode'),
        ('QUERY COUNT:', 'query_count'),
        ('ANSWER COUNT:', 'answer_count'),
        ('AUTHORITY RR COUNT:', 'authority_rr_count'),
        ('ADDON RR COUNT:', 'addon_rr_count'),
    )

    def handle(self):
        """Parse the received datagram and send the reply to the client."""
        datagram = self.request[0]
        sock = self.request[1]
        header, question_records, rest = self.parse_request(datagram)
        reply = self.build_response(header, question_records, rest)
        sock.sendto(reply, self.client_address)

    def build_response(self, request_header, question_records, data):
        """Return the encoded response for the given questions.

        One resource record is created per question; its data comes from
        NAME_IPADDR and is empty when the name is not in the table. The
        response consists of the header followed by those records. *data*
        is accepted but not used.
        """
        answers = []
        for question in question_records:
            labels = get_qname(question.qname)
            address = NAME_IPADDR.get(tuple(labels)) or b''
            answer = ResourceRecord()
            answer.name_ = question.qname
            answer.type_ = question.qtype
            answer.class_ = question.qclass
            answer.ttl = TTL
            answer.resource_data = address
            answer.resource_data_length = len(answer.resource_data)
            answers.append(answer)

        header = header_factory()
        header.identifier = request_header.identifier
        header.qr = QueryResponse.response.value
        header.opcode = request_header.opcode
        header.aa = AuthoritativeAnswer.yes.value
        header.tc = TrunCation.no.value
        header.query_count = 0  # the questions are not echoed back
        header.answer_count = len(answers)
        header.authority_rr_count = 0
        header.addon_rr_count = 0

        encoded_header = header.encode()
        display_buffer('header', encoded_header)
        response = encoded_header
        for answer in answers:
            encoded_answer = answer.encode()
            response += encoded_answer
            display_buffer('resource record', encoded_answer)
        return response

    def parse_request(self, data):
        """Parse *data* with the module-level parse_request() and log it.

        Prints the raw input, every header field, every question and the
        leftover bytes, then returns
        ``(header, question_records, remaining_bytes)``.
        """
        rule = '-' * 40
        print('*' * 40)
        print(data)
        header, question_records, rest = parse_request(data)
        print(rule)
        for label, attribute in self._HEADER_LOG_FIELDS:
            print(label, getattr(header, attribute))
        print(rule)
        for question in question_records:
            print('QNAME:', list(question.get_qname()))
            print('QTYPE:', question.qtype)
            print('QCLASS:', question.qclass)
            print(rule)
        print(rest)
        return header, question_records, rest
def main(argv=sys.argv[1:]):
    """Run the UDP name server until it is interrupted.

    Options:
        --host       address to bind to (default ``127.0.0.1``)
        -p, --port   UDP port to bind to (default ``53``)

    The port is parsed as an integer; previously a port given on the
    command line stayed a string and could not be used as a socket address.
    The listening socket is closed when serving stops. ``shutdown()`` is
    not called: it is meant to be called from another thread while
    ``serve_forever()`` is running.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', default='127.0.0.1')
    parser.add_argument('-p', '--port', type=int, default=53)
    args = parser.parse_args(argv)

    server = ThreadingUDPServer((args.host, args.port), UDPNameRequestHandler)
    try:
        server.serve_forever()
    finally:
        # Release the socket on every exit path (e.g. Ctrl-C).
        server.server_close()
if __name__ == '__main__':
sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment