Created
November 21, 2021 07:31
-
-
Save inactive123/1613a96ac82d28fbb54aaf26768b9741 to your computer and use it in GitHub Desktop.
retroarch_tunnel_server.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# RetroArch - A frontend for libretro. | |
# Copyright (C) 2021 - The RetroArch team | |
# | |
# RetroArch is free software: you can redistribute it and/or modify it under the terms | |
# of the GNU General Public License as published by the Free Software Found- | |
# ation, either version 3 of the License, or (at your option) any later version. | |
# | |
# RetroArch is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; | |
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR | |
# PURPOSE. See the GNU General Public License for more details. | |
# | |
# You should have received a copy of the GNU General Public License along with RetroArch. | |
# If not, see <http://www.gnu.org/licenses/>. | |
from __future__ import annotations | |
from typing import ( | |
Optional, | |
Union, | |
ClassVar, | |
Tuple, | |
List, | |
Dict, | |
cast | |
) | |
import sys | |
import os | |
import enum | |
import socket | |
import asyncio | |
import configparser | |
from abc import ABC, abstractmethod | |
from datetime import datetime | |
from pathlib import Path | |
class Error(Exception): | |
pass | |
@enum.unique | |
class LogLevel(enum.IntEnum): | |
NONE = 0 | |
ERROR = 1 | |
WARN = 2 | |
INFO = 3 | |
class Logger: | |
__slots__ = ("__path", "__level", "__lock") | |
__path: Path | |
__level: LogLevel | |
__lock: asyncio.Lock | |
def __init__(self, path: Path, level: LogLevel = LogLevel.NONE): | |
self.__path = path | |
self.__level = level | |
self.__lock = asyncio.Lock() | |
async def __call__(self, level: LogLevel, message: str) -> None: | |
if level is LogLevel.NONE or not message: | |
return # Should this be silent? | |
if self.__level >= level: | |
now: datetime = datetime.now() | |
async with self.__lock: | |
try: | |
with self.__path.open("a", encoding = "UTF-8") as log: | |
log.write(f"({now:%d}/{now:%m}/{now:%Y} {now:%H}:{now:%M}:{now:%S}) [{level.name}] {message}\n") | |
except OSError: | |
pass # Should this be silent? | |
class ConfigError(Error): | |
pass | |
class Config: | |
__slots__ = ( | |
"__server_port", | |
"__server_timeout", | |
"__session_max", | |
"__session_clients", | |
"__log_path", | |
"__log_level" | |
) | |
__server_port: int | |
__server_timeout: float | |
__session_max: int | |
__session_clients: int | |
__log_path: Path | |
__log_level: LogLevel | |
def __init__(self, path: Optional[Path] = None): | |
ini: configparser.ConfigParser = configparser.ConfigParser( | |
delimiters = '=', | |
comment_prefixes = ';', | |
interpolation = None | |
) | |
try: | |
if not ini.read(path if path is not None else Path(sys.argv[0]).with_suffix(".ini")): | |
raise ConfigError("Configuration not found.") | |
try: | |
server: configparser.SectionProxy = ini["Server"] | |
except KeyError: | |
raise ConfigError("Server section not found.") | |
else: | |
try: | |
self.__server_port = int(server["Port"]) | |
if self.__server_port not in range(1, 65536): | |
raise ValueError | |
except KeyError: | |
raise ConfigError("Server port not found.") | |
except ValueError: | |
raise ConfigError("Invalid server port.") | |
try: | |
self.__server_timeout = float(server["Timeout"]) | |
if self.__server_timeout <= 0: | |
raise ValueError | |
except KeyError: | |
raise ConfigError("Server timeout not found.") | |
except ValueError: | |
raise ConfigError("Invalid server timeout.") | |
try: | |
session: configparser.SectionProxy = ini["Session"] | |
except KeyError: | |
raise ConfigError("Session section not found.") | |
else: | |
try: | |
self.__session_max = int(session["Max"]) | |
if self.__session_max < 0: | |
raise ValueError | |
except KeyError: | |
raise ConfigError("Session max not found.") | |
except ValueError: | |
raise ConfigError("Invalid session max.") | |
try: | |
self.__session_clients = int(session["Clients"]) | |
if self.__session_clients < 0: | |
raise ValueError | |
except KeyError: | |
raise ConfigError("Session clients not found.") | |
except ValueError: | |
raise ConfigError("Invalid session clients.") | |
try: | |
log: configparser.SectionProxy = ini["Log"] | |
except KeyError: | |
raise ConfigError("Log section not found.") | |
else: | |
try: | |
self.__log_path: Path = Path(log["Path"]) | |
except KeyError: | |
raise ConfigError("Log path not found.") | |
try: | |
level_name: str = log["Level"].upper() | |
except KeyError: | |
raise ConfigError("Log level not found.") | |
try: | |
self.__log_level = LogLevel[level_name] | |
except KeyError: | |
raise ConfigError("Invalid log level.") | |
except configparser.Error: | |
raise ConfigError("Invalid configuration.") | |
@property | |
def port(self) -> int: return self.__server_port | |
@property | |
def timeout(self) -> float: return self.__server_timeout | |
@property | |
def max_sessions(self) -> int: return self.__session_max | |
@property | |
def max_clients_per_session(self) -> int: return self.__session_clients | |
@property | |
def log_path(self) -> Path: return self.__log_path | |
@property | |
def log_level(self) -> LogLevel: return self.__log_level | |
class TunnelError(Error): | |
pass | |
class TunnelMagic: | |
SESSION: ClassVar[bytes] = b'RATS' | |
LINK: ClassVar[bytes] = b'RATL' | |
PING: ClassVar[bytes] = b'RATP' | |
@staticmethod | |
def size() -> int: return 4 | |
@staticmethod | |
def unique_size() -> int: return 12 | |
class Tunnel(ABC): | |
__slots__ = ("_config", "_logger", "_lock") | |
_config: Config | |
_logger: Logger | |
_lock: asyncio.Lock | |
def __init__(self, config: Config, logger: Logger): | |
self._config = config | |
self._logger = logger | |
self._lock = asyncio.Lock() | |
async def log_info(self, message: str) -> None: | |
await self._logger(LogLevel.INFO, message) | |
async def log_warn(self, message: str) -> None: | |
await self._logger(LogLevel.WARN, message) | |
async def log_error(self, message: str) -> None: | |
await self._logger(LogLevel.ERROR, message) | |
@abstractmethod | |
async def __call__(self) -> None: | |
pass | |
class TunnelClient(Tunnel): | |
__slots__ = ( | |
"_reader", | |
"_writer", | |
"_unique" | |
) | |
_reader: asyncio.StreamReader | |
_writer: asyncio.StreamWriter | |
_unique: bytes | |
def __init__(self, config: Config, logger: Logger, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): | |
super().__init__(config, logger) | |
self._reader = reader | |
self._writer = writer | |
self._unique = b"" | |
def __del__(self) -> None: | |
try: | |
self._writer.close() | |
except Exception: | |
pass | |
@property | |
@abstractmethod | |
def session_owner(self) -> bool: | |
pass | |
@property | |
def address(self) -> str: | |
return cast(str, self._writer.get_extra_info("peername")[0]) | |
@property | |
def port(self) -> int: | |
return cast(int, self._writer.get_extra_info("peername")[1]) | |
async def unique(self) -> bytes: | |
async with self._lock: | |
return self._unique | |
async def set_unique(self, unique: bytes) -> None: | |
async with self._lock: | |
self._unique = unique | |
class SessionOwner(TunnelClient): | |
__slots__ = ("__connected", "__ready") | |
__connected: int | |
__ready: bool | |
def __init__(self, config: Config, logger: Logger, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): | |
super().__init__(config, logger, reader, writer) | |
self.__connected = 1 # Host counts as one. | |
self.__ready = False | |
def __bool__(self) -> bool: | |
return self.__ready | |
@property | |
def session_owner(self) -> bool: | |
return True | |
async def request_link(self, user: SessionUser) -> bool: | |
async with self._lock: | |
if not self.__ready: | |
return False | |
if self._reader.at_eof(): | |
return False | |
if self._writer.is_closing(): | |
return False | |
if self._config.max_clients_per_session and self.__connected >= self._config.max_clients_per_session: | |
return False | |
# Tell the host to establish a new link connection. | |
try: | |
self._writer.write(TunnelMagic.LINK + await user.unique()) | |
await self._writer.drain() | |
except Exception: | |
return False | |
else: | |
self.__connected += 1 | |
return True | |
async def request_unlink(self) -> bool: | |
async with self._lock: | |
if not self.__ready: | |
return False | |
self.__connected -= 1 | |
return True | |
async def __call__(self) -> None: | |
try: | |
# The first thing is sending the host his session id. | |
async with self._lock: | |
self._writer.write(TunnelMagic.SESSION + self._unique) | |
await self._writer.drain() | |
# We are ready to receive connection requests. | |
self.__ready = True | |
ping_magic: bytes | |
ping_requests: int = 0 | |
while not self._reader.at_eof(): | |
try: | |
# We want to request a ping every minute. | |
ping_magic = await asyncio.wait_for(self._reader.readexactly(len(TunnelMagic.PING)), 60) | |
except asyncio.TimeoutError: | |
# Attempt to ping the host a total of 3 times before timeouting. | |
if ping_requests < 3: | |
async with self._lock: | |
self._writer.write(TunnelMagic.PING) | |
await self._writer.drain() | |
ping_requests += 1 | |
else: | |
await self.log_error(f"Tunnel session timeout for: {self.address}|{self.port}") | |
break | |
else: | |
# We received something, but we are only allowing for ping responses. | |
if ping_magic == TunnelMagic.PING: | |
ping_requests = 0 | |
else: | |
await self.log_error(f"Tunnel session received non-ping data for: {self.address}|{self.port}") | |
break | |
except Exception: | |
pass | |
async with self._lock: | |
self.__ready = False | |
await self.log_info(f"Tunnel session closed for: {self.address}|{self.port}") | |
class SessionUser(TunnelClient): | |
__slots__ = ("__host", "__owner", "__link", "__linked") | |
__host: bool | |
__owner: Optional[SessionOwner] | |
__link: Optional[SessionUser] | |
__linked: asyncio.Event | |
def __init__(self, config: Config, logger: Logger, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, host: bool = False): | |
super().__init__(config, logger, reader, writer) | |
self.__host = host | |
self.__owner = None | |
self.__link = None | |
self.__linked = asyncio.Event() | |
def __bool__(self) -> bool: | |
return self.__linked.is_set() | |
async def __forward(self, data: bytes) -> bool: | |
# Forward data from one peer to another through our tunnel. | |
try: | |
async with self._lock: | |
self._writer.write(data) | |
await self._writer.drain() | |
except Exception: | |
return False | |
else: | |
return True | |
@property | |
def session_owner(self) -> bool: | |
return False | |
@property | |
def is_host(self) -> bool: | |
return self.__host | |
async def owner(self) -> Optional[SessionOwner]: | |
async with self._lock: | |
return self.__owner | |
async def set_owner(self, owner: Optional[SessionOwner]) -> None: | |
async with self._lock: | |
self.__owner = owner | |
async def try_link(self, link: SessionUser) -> bool: | |
owner: Optional[SessionOwner] = await self.owner() | |
if owner is None: | |
return False | |
async with owner._lock: | |
async with self._lock: | |
async with link._lock: | |
if self.__link is not None or link.__link is not None: | |
return False | |
if self.__linked.is_set() or link.__linked.is_set(): | |
return False | |
if self._reader.at_eof() or link._reader.at_eof(): | |
return False | |
if self._writer.is_closing() or link._writer.is_closing(): | |
return False | |
self.__link = link | |
link.__link = self | |
self.__linked.set() | |
link.__linked.set() | |
return True | |
async def try_unlink(self) -> bool: | |
owner: Optional[SessionOwner] = await self.owner() | |
if owner is None: | |
return False | |
async with owner._lock: | |
async with self._lock: | |
link: Optional[SessionUser] = self.__link | |
if link is None: | |
return False | |
async with link._lock: | |
self.__link = None | |
link.__link = None | |
return True | |
async def __call__(self) -> None: | |
timeout: float = self._config.timeout | |
try: | |
# Wait until we are linked to a connection or a timeout occurs. | |
await asyncio.wait_for(self.__linked.wait(), timeout) | |
except asyncio.TimeoutError: | |
await self.log_error(f"Timeout while awaiting link for: {self.address}|{self.port}") | |
else: | |
# The tunnel is ready. | |
try: | |
data: bytes | |
link: Optional[SessionUser] | |
# Forward data until the connection is closed or until a timeout occurs. | |
while not self._reader.at_eof(): | |
try: | |
data = await asyncio.wait_for(self._reader.read(8192), timeout) | |
except asyncio.TimeoutError: | |
await self.log_error(f"Tunnel link timeout for: {self.address}|{self.port}") | |
break | |
if not data: | |
break | |
async with self._lock: | |
# Check if our link is still around. | |
link = self.__link | |
if link is None: | |
break | |
if not await link.__forward(data): | |
await self.log_error(f"Failed to forward data from: {self.address}|{self.port}") | |
break | |
except Exception: | |
pass | |
await self.log_info(f"Tunnel closed for: {self.address}|{self.port}") | |
class TunnelServer(Tunnel): | |
__slots__ = ("__clients", "__sessions") | |
__clients: Dict[bytes, TunnelClient] | |
__sessions: int | |
def __init__(self, config: Config, logger: Logger): | |
super().__init__(config, logger) | |
self.__clients = {} | |
self.__sessions = 0 | |
async def __request_session(self) -> bool: | |
if self._config.max_sessions: | |
async with self._lock: | |
if self.__sessions >= self._config.max_sessions: | |
return False | |
self.__sessions += 1 | |
return True | |
async def __free_session(self) -> None: | |
if self.__sessions > 0: | |
async with self._lock: | |
self.__sessions -= 1 | |
async def __add_client(self, client: TunnelClient) -> None: | |
unique: bytes | |
invalid_unique: bytes = bytes(bytearray(TunnelMagic.unique_size())) | |
# Generate a new unique id for this client. | |
async with self._lock: | |
while True: | |
unique = os.urandom(TunnelMagic.unique_size()) | |
if unique == invalid_unique: | |
continue | |
if unique not in self.__clients: | |
break | |
await client.set_unique(unique) | |
self.__clients[unique] = client | |
async def __remove_client(self, client: TunnelClient) -> None: | |
unique: bytes = await client.unique() | |
async with self._lock: | |
try: | |
del self.__clients[unique] | |
except KeyError: | |
await self.log_warn(f"Failed to remove client: {client.address}|{client.port}") | |
async def __session_link_request(self, user: SessionUser, session_id: bytes) -> Optional[SessionOwner]: | |
unique: bytes = await user.unique() | |
if session_id == unique: | |
return await self.log_error(f"Invalid session link request from: {user.address}|{user.port}") | |
try: | |
owner: TunnelClient = self.__clients[session_id] | |
except KeyError: | |
return await self.log_error(f"Failed to find session for: {user.address}|{user.port}") | |
else: | |
if not owner.session_owner: | |
return await self.log_error(f"Invalid session id from: {user.address}|{user.port}") | |
await user.set_owner(cast(SessionOwner, owner)) | |
if not await cast(SessionOwner, owner).request_link(user): | |
return await self.log_error(f"Failed to establish tunnel link for: {user.address}|{user.port}") | |
return cast(SessionOwner, owner) | |
async def __session_link(self, user: SessionUser, peer_id: bytes) -> Optional[SessionUser]: | |
unique: bytes = await user.unique() | |
if peer_id == unique: | |
return await self.log_error(f"Invalid peer link request from: {user.address}|{user.port}") | |
try: | |
peer: TunnelClient = self.__clients[peer_id] | |
except KeyError: | |
return await self.log_error(f"Failed to find peer for: {user.address}|{user.port}") | |
else: | |
if peer.session_owner or cast(SessionUser, peer).is_host: | |
return await self.log_error(f"Invalid peer id from: {user.address}|{user.port}") | |
owner: Optional[SessionOwner] = await cast(SessionUser, peer).owner() | |
if owner is None: | |
return await self.log_error(f"Invalid peer session owner for: {user.address}|{user.port}") | |
await user.set_owner(owner) | |
if not await cast(SessionUser, peer).try_link(user): | |
return await self.log_error(f"Failed to establish tunnel link for: {user.address}|{user.port}") | |
return cast(SessionUser, peer) | |
async def __handle_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: | |
try: | |
sock_info: Tuple[Union[str, int], ...] = writer.get_extra_info("peername") | |
addr: str = cast(str, sock_info[0]) | |
port: int = cast(int, sock_info[1]) | |
await self.log_info(f"Received connection from: {addr}|{port}") | |
try: | |
writer.get_extra_info("socket").setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, True) | |
except Exception: | |
await self.log_warn(f"Failed to set TCP_NODELAY for: {addr}|{port}") | |
try: | |
# Do not wait for more than 30 seconds for the magic. | |
magic: bytes = await asyncio.wait_for(reader.readexactly(TunnelMagic.size()), 30) | |
except Exception: | |
return await self.log_error(f"Failed to receive tunnel magic from: {addr}|{port}") | |
if magic == b'RANP': | |
# Client does not support tunnels. | |
# Send a fake header to let it know it's outdated. | |
writer.write(b'RANP' + bytearray(20)) | |
await writer.drain() | |
return await self.log_error(f"Unsupported client from: {addr}|{port}") | |
elif magic in (TunnelMagic.SESSION, TunnelMagic.LINK): | |
try: | |
# Do not wait for more than 30 seconds for the unique id. | |
unique: bytes = await asyncio.wait_for(reader.readexactly(TunnelMagic.unique_size()), 30) | |
except Exception: | |
return await self.log_error(f"Failed to receive tunnel unique id from: {addr}|{port}") | |
client: TunnelClient | |
owner: Optional[SessionOwner] | |
peer: Optional[SessionUser] | |
if magic == TunnelMagic.SESSION: | |
if unique == bytearray(TunnelMagic.unique_size()): | |
# All zeros. | |
# Client requested a new session. | |
if await self.__request_session(): | |
client = SessionOwner(self._config, self._logger, reader, writer) | |
await self.__add_client(client) | |
await self.log_info(f"Tunnel session created for: {addr}|{port}") | |
else: | |
return await self.log_error(f"Refused to create tunnel session for: {addr}|{port}") | |
else: | |
# Client trying to link to an existing session. | |
client = SessionUser(self._config, self._logger, reader, writer, host = False) | |
await self.__add_client(client) | |
owner = await self.__session_link_request(client, unique) | |
if owner is not None: | |
await self.log_info(f"Pending tunnel linking for: {addr}|{port}") | |
else: | |
await self.__remove_client(client) | |
return await self.log_error(f"Tunnel linking failed for: {addr}|{port}") | |
else: | |
# Session host client trying to link to an user client. | |
client = SessionUser(self._config, self._logger, reader, writer, host = True) | |
await self.__add_client(client) | |
peer = await self.__session_link(client, unique) | |
if peer is not None: | |
await self.log_info( | |
f"Tunnel linking completed for: {peer.address}|{peer.port} <-> {addr}|{port}" | |
) | |
else: | |
await self.__remove_client(client) | |
return await self.log_error(f"Tunnel linking failed for: {addr}|{port}") | |
await client() | |
await self.__remove_client(client) | |
if client.session_owner: | |
await self.__free_session() | |
else: | |
# Make sure to tell our link we are done for. | |
await cast(SessionUser, client).try_unlink() | |
if not cast(SessionUser, client).is_host: | |
await cast(SessionOwner, owner).request_unlink() | |
else: | |
return await self.log_error(f"Unknown tunnel magic from: {addr}|{port}") | |
finally: | |
# Close the connection once we are done. | |
try: | |
writer.write_eof() | |
except Exception: | |
pass | |
writer.close() | |
await writer.wait_closed() | |
await self.log_info(f"Connection closed for: {addr}|{port}") | |
async def __call__(self) -> None: | |
server: asyncio.AbstractServer | |
try: | |
if sys.platform == "win32": | |
server = await asyncio.start_server(self.__handle_client, port = self._config.port) | |
else: | |
server = await asyncio.start_server(self.__handle_client, port = self._config.port, reuse_port = True) | |
except Exception: | |
raise TunnelError("Failed to create tunnel server.") | |
async with server: | |
sock_info: Tuple[Union[str, int], ...] | |
address: str | |
for sock in cast(List[socket.socket], server.sockets): | |
sock_info = sock.getsockname() | |
address = f"{sock_info[0]}|{sock_info[1]}" | |
print(f"[*] Tunnel server listening on: {address}") | |
await self.log_info(f"Tunnel server listening on: {address}") | |
await server.serve_forever() | |
async def main() -> None: | |
config: Config = Config(Path(sys.argv[1].strip()) if len(sys.argv) > 1 else None) | |
logger: Logger = Logger(config.log_path, config.log_level) | |
server: TunnelServer = TunnelServer(config, logger) | |
await server() | |
if __name__ == "__main__": | |
asyncio.run(main()) | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment