Skip to content

Instantly share code, notes, and snippets.

@sorz
Last active July 17, 2017 22:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sorz/5bc7bdb8dcfdcbb5df4a8bcffeb6da8e to your computer and use it in GitHub Desktop.
Save sorz/5bc7bdb8dcfdcbb5df4a8bcffeb6da8e to your computer and use it in GitHub Desktop.
Forward DNS queries (over UDP) to a set of servers in parallel, then send back the fastest reply and ignore the others. TCP queries are forwarded directly to a fixed server.
#!/usr/bin/env python3
"""
Forward DNS queries (over UDP) to a set of servers in parallel, then send
back the fastest reply and ignore the others. TCP queries are forwarded
directly to a fixed server.
"""
import asyncio
import socket
# Local address the forwarder binds, for both its UDP and its TCP listener.
LISTEN = ('127.0.0.1', 5353)
# Upstream UDP servers every query is raced against: 127.0.0.1:8130-8134.
UDP_SERVERS = [('127.0.0.1', 8130 + offset) for offset in range(5)]
# Single upstream that TCP connections are relayed to.
TCP_SERVER = ('8.8.8.8', 53)
# Seconds to wait for the first UDP reply before giving up on a query.
UDP_MAX_WAIT = 5
# Seconds a TCP relay may wait on one read before the connection is dropped.
TCP_TIMEOUT = 10
class UDPServerProtocol(asyncio.DatagramProtocol):
    """Datagram protocol that races each UDP query across ``UDP_SERVERS``.

    Every datagram received on the listening socket is sent, unchanged, to
    all addresses in ``UDP_SERVERS`` from a fresh non-blocking UDP socket.
    The first datagram that arrives on that socket within ``UDP_MAX_WAIT``
    seconds is sent back to the client. The socket is then closed, so slower
    replies are dropped.

    Subclassing ``asyncio.DatagramProtocol`` provides the ``connection_lost``
    and other protocol callbacks that the datagram transport may invoke.
    """

    # Receive buffer for upstream replies. 65535 bytes covers any UDP
    # payload, so large responses are not silently truncated.
    RECV_SIZE = 65535

    def connection_made(self, transport):
        """Keep the listening transport and the loop it runs on."""
        self.transport = transport
        self.loop = asyncio.get_event_loop()

    def error_received(self, exc):
        """Report an error raised by an operation on the listening socket."""
        print(f'error on udp listener: {exc}')

    def datagram_received(self, data, addr):
        """Fan out one client query and relay the fastest upstream reply.

        :param data: raw query datagram from the client.
        :param addr: client address; the reply is sent back to it.
        """
        print(f'receive {len(data)} bytes from {addr}')
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.setblocking(False)
        sent = 0
        for remote in UDP_SERVERS:
            # One failing upstream must neither stop the query from reaching
            # the others nor leak the socket by escaping this callback.
            try:
                sock.sendto(data, remote)
            except OSError as e:
                print(f'error on sending udp to {remote}: {e}')
            else:
                sent += 1
        if not sent:
            print('failed to forward query to any udp server')
            sock.close()
            return
        # NOTE(review): the socket is unconnected, so sock_recv accepts the
        # first datagram from *any* sender, not only from UDP_SERVERS.
        coro = self.loop.sock_recv(sock, self.RECV_SIZE)
        task = self.loop.create_task(asyncio.wait_for(coro, UDP_MAX_WAIT))

        def done(task):
            try:
                if task.cancelled():
                    # task.exception() would raise CancelledError here.
                    return
                exc = task.exception()
                if exc is None:
                    resp = task.result()
                    print(f'receive response {len(resp)} bytes')
                    self.transport.sendto(resp, addr)
                elif isinstance(exc, asyncio.TimeoutError):
                    print('no any reply from all servers')
                else:
                    print(f'error on receiving udp: {exc}')
            finally:
                # Always release the per-query socket, whatever the outcome.
                sock.close()

        task.add_done_callback(done)
async def tcp_forwarder(local_r, local_w):
    """Relay one client TCP connection to ``TCP_SERVER`` and back.

    Intended as the ``client_connected_cb`` of ``asyncio.start_server``.
    Two ``piping`` tasks copy bytes, one per direction. When either task
    raises, or both finish, the other is cancelled and both connections are
    closed.

    :param local_r: stream reader for the client connection.
    :param local_w: stream writer for the client connection.
    """
    peername = local_w.get_extra_info('peername')
    print(f'tcp connected from {peername}')
    try:
        remote_r, remote_w = await asyncio.open_connection(*TCP_SERVER)
    except OSError as e:
        # Upstream refused/unreachable: drop the client instead of leaking it.
        print(f'error on connecting to {TCP_SERVER}: {e}')
        local_w.close()
        return
    # asyncio.wait() needs tasks/futures (bare coroutines are rejected since
    # Python 3.11), so schedule each direction explicitly.
    tasks = [
        asyncio.ensure_future(piping(local_r, remote_w)),
        asyncio.ensure_future(piping(remote_r, local_w)),
    ]
    try:
        done, _ = await asyncio.wait(
            tasks, return_when=asyncio.FIRST_EXCEPTION)
        for task in done:
            if task.cancelled():
                continue
            # Retrieve every exception so asyncio does not log "Task
            # exception was never retrieved". A read timeout is an expected
            # way for an idle connection to end, so only report other errors.
            exc = task.exception()
            if exc is not None and not isinstance(exc, asyncio.TimeoutError):
                print(f'error on tcp forwarding: {exc}')
    finally:
        # Stop whichever direction is still running (a no-op for finished
        # tasks) and close both ends, also if this coroutine is cancelled.
        for task in tasks:
            task.cancel()
        local_w.close()
        remote_w.close()
async def piping(reader, writer):
    """Copy bytes from *reader* to *writer* until EOF, then half-close *writer*.

    Each read may wait at most ``TCP_TIMEOUT`` seconds. If the peer stalls
    longer than that, the ``asyncio.wait_for`` timeout propagates to the
    caller.
    """
    chunk = await asyncio.wait_for(reader.read(8196), TCP_TIMEOUT)
    while chunk:
        writer.write(chunk)
        # Apply back-pressure before pulling the next chunk.
        await writer.drain()
        chunk = await asyncio.wait_for(reader.read(8196), TCP_TIMEOUT)
    # An empty read means the source hit EOF; pass the half-close along.
    writer.write_eof()
def main():
    """Serve the forwarder on ``LISTEN`` over both UDP and TCP until Ctrl-C."""
    # Create the loop explicitly. Calling asyncio.get_event_loop() with no
    # running loop is deprecated since Python 3.10 and fails on newer
    # versions.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        listen = loop.create_datagram_endpoint(UDPServerProtocol, LISTEN)
        udp, _ = loop.run_until_complete(listen)
        listen = asyncio.start_server(tcp_forwarder, *LISTEN)
        tcp = loop.run_until_complete(listen)
        print('running')
        try:
            loop.run_forever()
        except KeyboardInterrupt:
            pass
        udp.close()
        tcp.close()
    finally:
        # Close the loop even if binding one of the listeners failed.
        loop.close()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment