Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save Snawoot/3de21bc5acc5c9feda28b6ac7924cde9 to your computer and use it in GitHub Desktop.
Save Snawoot/3de21bc5acc5c9feda28b6ac7924cde9 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# coding=utf-8
"""
LICENSE http://www.apache.org/licenses/LICENSE-2.0
"""
import datetime
import struct
import sys
import time
import threading
import traceback
import SocketServer
from dnslib import *
from collections import Counter
import redis
class DomainName(str):
    """A ``str`` subclass that builds subdomains through attribute access.

    ``DomainName('example.com.').www`` evaluates to
    ``DomainName('www.example.com.')``, so records can be keyed as ``D.test``
    instead of by string concatenation.
    """

    def __getattr__(self, item):
        # __getattr__ only runs when normal attribute lookup fails. Dunder
        # names are protocol probes (copy.deepcopy looks up __deepcopy__,
        # pickle looks up __setstate__, ...) and must get AttributeError;
        # returning a DomainName there makes those protocols call a string.
        if item.startswith('__') and item.endswith('__'):
            raise AttributeError(item)
        return DomainName(item + '.' + self)
# Zone apex served here; attribute access builds subdomains (e.g. D.test).
D = DomainName('l.vm-0.com.')
# Address placed in every A answer.
IP = '159.203.173.33'
# Local address and port that both the UDP and TCP listeners bind to.
NS_IP = '10.17.0.5'
PORT = 53
# TTL (seconds) applied to every resource record in a reply.
TTL = 5

soa_record = SOA(
    mname=DomainName("float.vm-0.com."),  # primary name server
    rname=DomainName("vladislav.vm-0.com."),  # administrator mailbox
    times=(
        201307231,  # serial
        3600,  # refresh: 1 hour
        10800,  # retry: 3 hours
        86400,  # expire: 24 hours
        3600,  # minimum: 1 hour
    ),
)
ns_records = [NS(DomainName("float.vm-0.com."))]

# Explicit record sets per name; other names inside D get the catch-all
# A(IP) answer from dns_response.
records = {}
records[D] = [A(IP), soa_record] + ns_records
records[D.test] = [A(IP)]

# (qname, qtype) -> number of 'dropN' queries dropped so far.
ctr = Counter()
redis_pool = redis.BlockingConnectionPool(timeout=1, max_connections=10)
def dns_response(request):
    """Build the packed wire-format reply to a parsed DNS query.

    For names at or below the zone apex ``D``:

    * names listed in ``records`` get every configured RR whose type matches
      the query (``*``/``ANY`` match all; CNAME and DNAME are always added);
    * any other name in the zone gets ``A(IP)`` for A, ``*`` and ANY queries;
    * the zone's SOA and NS records are put in the authority section.

    Names outside ``D`` get REFUSED with the AA bit clear: this server holds
    no data for them and must not claim to be authoritative.

    :param request: ``dnslib.DNSRecord`` parsed from the client's query.
    :returns: packed reply bytes.
    """
    qname = request.q.qname
    qn = str(qname).lower()
    qtype = request.q.qtype
    qt = QTYPE[qtype]
    in_zone = qn == D or qn.endswith('.' + D)
    reply = DNSRecord(
        # RFC 1035 s4.1.1: the RD bit is copied from the query into the response.
        DNSHeader(id=request.header.id, qr=1, aa=1 if in_zone else 0,
                  rd=request.header.rd, ra=1),
        q=request.q,
    )
    if not in_zone:
        reply.header.rcode = RCODE.REFUSED
        return reply.pack()
    if qn in records:
        for rdata in records[qn]:
            rqt = rdata.__class__.__name__
            if qt in ('*', 'ANY', rqt) or rqt in ('CNAME', 'DNAME'):
                reply.add_answer(RR(rname=qname, rtype=getattr(QTYPE, rqt),
                                    rclass=1, ttl=TTL, rdata=rdata))
    elif qt in ('*', 'ANY', 'A'):
        # Catch-all: every other name inside the zone resolves to IP.
        reply.add_answer(RR(rname=qname, rtype=QTYPE.A, rclass=1, ttl=TTL,
                            rdata=A(IP)))
    reply.add_auth(RR(rname=D, rtype=QTYPE.SOA, rclass=1, ttl=TTL,
                      rdata=soa_record))
    for rdata in ns_records:
        reply.add_auth(RR(rname=D, rtype=QTYPE.NS, rclass=1, ttl=TTL,
                          rdata=rdata))
    return reply.pack()
class BaseRequestHandler(SocketServer.BaseRequestHandler):
    """Transport-independent DNS request handling.

    Subclasses implement ``get_data`` (return the raw query bytes) and
    ``send_data`` (transmit raw reply bytes). ``handle`` parses and logs the
    query, and answers it via ``dns_response``, unless a leading ``dropN``
    label asks for the query to be dropped. It then records any ``tag*``
    labels in Redis.
    """

    # Guards the check-then-update on the module-level ``ctr`` Counter. The
    # Threading*Server classes call handle() from one thread per request, so
    # concurrent queries for the same name would otherwise race.
    _ctr_lock = threading.Lock()

    def get_data(self):
        """Return the raw DNS query bytes. Implemented by subclasses."""
        raise NotImplementedError

    def send_data(self, data):
        """Send the raw DNS reply bytes ``data``. Implemented by subclasses."""
        raise NotImplementedError

    @staticmethod
    def _parse_drop_count(label):
        """Return N for a label of the form 'dropN'; 0 if N is not an integer."""
        try:
            return int(label[4:])
        except ValueError:
            return 0

    def _should_answer(self, qn, qt, label):
        """Decide whether a 'dropN' query should be answered now.

        The first N queries for the same (name, type) pair are dropped and
        counted in ``ctr``. The next one is answered and the count is reset.
        """
        key = (qn, qt)
        drop = self._parse_drop_count(label)
        with self._ctr_lock:
            if drop <= ctr[key]:
                del ctr[key]
                return True
            ctr[key] += 1
            return False

    def _record_tags(self, parts, ts):
        """Append (timestamp, client IP) to a Redis list for each 'tag*' label."""
        tags = [part for part in parts if part.startswith('tag')]
        if not tags:
            return
        r = redis.StrictRedis(connection_pool=redis_pool)
        entry = "%f\0%s" % (ts, self.client_address[0])
        for tag in tags:
            key = "dnsleak:" + tag
            r.rpush(key, entry)
            r.expire(key, 3600)  # TTL is refreshed on every new entry

    def handle(self):
        """Process one request; ``Exception``s are printed to stderr, not raised."""
        ts = time.time()
        now = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')
        try:
            request = DNSRecord.parse(self.get_data())
            qt = QTYPE[request.q.qtype]
            qn = str(request.q.qname).lower()
            # A leading 'www.' is ignored for the drop/tag label logic.
            if qn.startswith('www.'):
                qn = qn[4:]
            parts = qn.split('.')
            print("%s request %s (%s %s): %s %s?" % (
                self.__class__.__name__[:3], now,
                self.client_address[0], self.client_address[1],
                request.q.qname, qt))
            if not parts[0].startswith('drop') or self._should_answer(qn, qt, parts[0]):
                self.send_data(dns_response(request))
            self._record_tags(parts, ts)
        except Exception:
            traceback.print_exc(file=sys.stderr)
class TCPRequestHandler(BaseRequestHandler):
    """DNS over TCP: each message is preceded by a 2-byte big-endian length
    (RFC 1035 s4.2.2)."""

    def _recv_exactly(self, count):
        """Read exactly ``count`` bytes from the connection.

        A single ``recv`` may return only part of what the peer sent, so keep
        reading until the requested amount has arrived.

        :raises EOFError: if the peer closes the connection first.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = self.request.recv(remaining)
            if not chunk:
                raise EOFError("connection closed with %d of %d bytes unread"
                               % (remaining, count))
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def get_data(self):
        """Return one length-prefixed DNS message from the stream.

        The payload is binary wire format and is returned unmodified
        (never ``strip()``-ed: whitespace-valued bytes are legitimate data).
        """
        (size,) = struct.unpack('!H', self._recv_exactly(2))
        return self._recv_exactly(size)

    def send_data(self, data):
        """Send ``data`` framed with its 2-byte big-endian length.

        ``struct.error`` is raised if ``data`` exceeds 65535 bytes, which the
        2-byte prefix cannot represent.
        """
        return self.request.sendall(struct.pack('!H', len(data)) + data)
class UDPRequestHandler(BaseRequestHandler):
    """DNS over UDP: one datagram carries one message."""

    def get_data(self):
        """Return the received datagram exactly as it arrived.

        For datagram servers ``self.request`` is ``(data, socket)``. The
        payload is binary DNS wire format, so it must not be ``strip()``-ed:
        bytes such as 0x20 or 0x09-0x0d are legitimate anywhere in a DNS
        message, including the transaction ID.
        """
        return self.request[0]

    def send_data(self, data):
        """Send ``data`` back to the address the request came from."""
        return self.request[1].sendto(data, self.client_address)
if __name__ == '__main__':
    print("Starting nameserver...")
    servers = [
        SocketServer.ThreadingUDPServer((NS_IP, PORT), UDPRequestHandler),
        SocketServer.ThreadingTCPServer((NS_IP, PORT), TCPRequestHandler),
    ]
    for s in servers:
        # serve_forever() runs in its own thread; the Threading*Server mixin
        # then starts one more thread for each request.
        thread = threading.Thread(target=s.serve_forever)
        thread.daemon = True  # don't keep the process alive after the main thread exits
        thread.start()
        print("%s server loop running in thread: %s" % (
            s.RequestHandlerClass.__name__[:3], thread.name))
    try:
        # Idle in the main thread, flushing output so the log stays current.
        while True:
            time.sleep(1)
            sys.stderr.flush()
            sys.stdout.flush()
    except KeyboardInterrupt:
        pass
    finally:
        for s in servers:
            s.shutdown()      # stop the serve_forever() loop
            s.server_close()  # release the listening socket
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment