Created
March 6, 2021 18:43
-
-
Save psychon/b3272c4d51383015657f6f854a4caad6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
import select | |
import socket | |
import sys | |
import time | |
class Listener:
    """Listening socket that spawns a pair of forwarders per connection.

    Accepts TCP connections on 127.0.0.1:``listen_port`` and, for each one,
    opens a connection to 127.0.0.1:``target_port``.
    """

    def __init__(self, listen_port, target_port, delay):
        """Bind the listening socket.

        listen_port -- local port to accept connections on
        target_port -- local port every accepted connection is relayed to
        delay       -- delay in seconds handed to each created Forwarder
        """
        self.fd = socket.create_server(("127.0.0.1", listen_port))
        # Non-blocking: accept() is only attempted after select() reported
        # the socket readable.
        self.fd.settimeout(0)
        self.target_port = target_port
        self.delay = delay

    def register_fds(self, rlist, wlist):
        """Add the listening socket to ``rlist``; never requests a timeout."""
        rlist.append(self.fd.fileno())

    def process(self, readable, writable):
        """Accept one pending connection, if any.

        Returns a list with two Forwarder objects (client->server and
        server->client) when a connection was accepted and the target could
        be reached, otherwise None.
        """
        if self.fd.fileno() not in readable:
            return None
        try:
            client, addr = self.fd.accept()
        except (BlockingIOError, ConnectionAbortedError):
            # The peer went away between select() and accept(); nothing to do.
            return None
        print(f"Incoming connection from {addr}")
        try:
            server = socket.create_connection(("127.0.0.1", self.target_port))
        except OSError as e:
            # An unreachable target must not take the whole proxy down (and
            # must not leak the accepted socket): drop only this client.
            print("Error during socket.create_connection():")
            print(e)
            client.close()
            return None
        client.settimeout(0)
        server.settimeout(0)
        return [
            Forwarder(client, server, self.delay),
            Forwarder(server, client, self.delay),
        ]
class Forwarder:
    """One direction of a proxied connection, with an artificial delay.

    Data received from ``read`` is queued and written to ``write`` once
    ``delay`` seconds have passed. End of stream travels through the same
    queue, so the write side is only shut down after all earlier data has
    been delivered.

    The two sockets are shared with the Forwarder for the opposite
    direction, so they are never closed here, only half-closed.
    """

    def __init__(self, read, write, delay):
        """Store the sockets and the delay.

        read  -- non-blocking socket to receive from
        write -- non-blocking socket to send to
        delay -- seconds to hold each chunk before forwarding it
        """
        self.read = read
        self.write = write
        self.delay = delay
        # FIFO of (deadline, data) tuples; an entry with empty data marks
        # the end of the stream.
        self.pending_data = []
        # Set once recv() reported end of stream; the read side is then no
        # longer polled.
        self.eof = False

    def register_fds(self, rlist, wlist):
        """Register the file descriptors this forwarder is waiting on.

        The read socket is polled until end of stream. The write socket is
        only polled while the oldest queued entry is due. Returns the number
        of seconds until the oldest entry becomes due, or None when no
        timeout is needed.
        """
        if not self.eof:
            rlist.append(self.read.fileno())
        if not self.pending_data:
            return None
        # Monotonic clock: wall-clock adjustments must not distort the delay.
        timeout = self.pending_data[0][0] - time.monotonic()
        if timeout <= 0:
            wlist.append(self.write.fileno())
            return None
        return timeout

    def process(self, readable, writable):
        """Handle readiness reported by select().

        Returns True when this forwarder is finished and should be dropped
        by the caller, otherwise None.
        """
        if self.pending_data and self.write.fileno() in writable:
            if self._flush_head():
                return True
        if not self.eof and self.read.fileno() in readable:
            if self._receive():
                return True
        return None

    def _flush_head(self):
        """Send the oldest queued entry; return True when forwarding is over."""
        deadline, data = self.pending_data[0]
        if not data:
            # End-of-stream marker reached: everything before it was sent.
            self.shutdown_write()
            return True
        try:
            sent = self.write.send(data)
        except BlockingIOError:
            # Spurious writability; retry on the next round.
            return False
        except OSError as e:
            print("Error during socket.send():")
            print(e)
            return True
        if sent < len(data):
            # Partial write: keep the remainder under the original deadline.
            self.pending_data[0] = (deadline, data[sent:])
        else:
            self.pending_data.pop(0)
        return False

    def _receive(self):
        """Queue newly received data; return True on a fatal receive error."""
        try:
            data = self.read.recv(4096)
        except BlockingIOError:
            # Spurious readability; retry on the next round.
            return False
        except OSError as e:
            print("Error during socket.recv():")
            print(e)
            self.shutdown_write()
            return True
        self.pending_data.append((time.monotonic() + self.delay, data))
        if not data:
            # Do not drop the forwarder yet: queued data (and the marker
            # just appended) still has to be delivered after the delay.
            self.eof = True
        return False

    def shutdown_write(self):
        """Half-close the write socket, reporting but tolerating errors."""
        try:
            self.write.shutdown(socket.SHUT_WR)
        except OSError as e:
            print("Error during socket.shutdown():")
            print(e)
def main(listen_port, target_port, delay):
    """Run the proxy event loop; does not return normally.

    listen_port -- local port to accept connections on
    target_port -- local port connections are forwarded to
    delay       -- delay in seconds applied to each direction

    Every object in ``sockets`` provides ``register_fds(rlist, wlist)``,
    which may return a timeout in seconds, and ``process(readable,
    writable)``, which returns None (keep the object), a list (keep it and
    add the listed objects) or True (drop it).
    """
    sockets = [Listener(listen_port, target_port, delay)]
    while True:
        rlist, wlist = [], []
        timeout = None
        for sock in sockets:
            t = sock.register_fds(rlist, wlist)
            # Wake up for the earliest deadline any object asked for.
            if t is not None and (timeout is None or t < timeout):
                timeout = t
        readable, writeable, _ = select.select(rlist, wlist, [], timeout)
        new_sockets = []
        for sock in sockets:
            result = sock.process(readable, writeable)
            if result is None:
                new_sockets.append(sock)
            elif isinstance(result, list):
                new_sockets.append(sock)
                new_sockets.extend(result)
            elif result is True:
                # Finished: the object is not carried over to the next round.
                pass
            else:
                raise TypeError(f"unexpected result from process(): {result!r}")
        sockets = new_sockets
if __name__ == "__main__":
    print("This program is a TCP proxy which adds a delay to the connection")
    print(f"Usage: {sys.argv[0]} [port to listen on] [port on localhost to connect to] [delay in seconds]")
    print("Note that the round-trip delay will be twice the given delay.")
    # With arguments missing, stop after the usage text instead of dying
    # with an IndexError traceback.
    if len(sys.argv) < 4:
        sys.exit(1)
    main(int(sys.argv[1]), int(sys.argv[2]), float(sys.argv[3]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment