This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import sys | |
import logging | |
import pathlib | |
import os | |
from Client import Client | |
from datetime import datetime | |
class Server: | |
def __init__(self, ip: str, port: int, loop: asyncio.AbstractEventLoop): | |
''' | |
Parameters | |
---------- | |
ip : str | |
IP that the server will be using | |
port : int | |
Port that the server will be using | |
---------- | |
''' | |
self.__ip: str = ip | |
self.__port: int = port | |
self.__loop: asyncio.AbstractEventLoop = loop | |
self.__logger: logging.Logger = self.initialize_logger() | |
self.__clients: dict[asyncio.Task, Client] = {} | |
self.logger.info(f"Server Initialized with {self.ip}:{self.port}") | |
@property | |
def ip(self): | |
return self.__ip | |
@property | |
def port(self): | |
return self.__port | |
@property | |
def loop(self): | |
return self.__loop | |
@property | |
def logger(self): | |
return self.__logger | |
@property | |
def clients(self): | |
return self.__clients | |
def initialize_logger(self): | |
''' | |
Initializes a logger and generates a log file in ./logs. | |
Returns | |
------- | |
logging.Logger | |
Used for writing logs of varying levels to the console and log file. | |
------- | |
''' | |
path = pathlib.Path(os.path.join(os.getcwd(), "logs")) | |
path.mkdir(parents=True, exist_ok=True) | |
logger = logging.getLogger('Server') | |
logger.setLevel(logging.DEBUG) | |
ch = logging.StreamHandler() | |
fh = logging.FileHandler( | |
filename=f'logs/{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}_server.log' | |
) | |
ch.setLevel(logging.INFO) | |
fh.setLevel(logging.DEBUG) | |
formatter = logging.Formatter( | |
'[%(asctime)s] - %(levelname)s - %(message)s' | |
) | |
ch.setFormatter(formatter) | |
fh.setFormatter(formatter) | |
logger.addHandler(ch) | |
logger.addHandler(fh) | |
return logger | |
def start_server(self): | |
''' | |
Starts the server on IP and PORT. | |
''' | |
try: | |
self.server = asyncio.start_server( | |
self.accept_client, self.ip, self.port | |
) | |
self.loop.run_until_complete(self.server) | |
self.loop.run_forever() | |
except Exception as e: | |
self.logger.error(e) | |
except KeyboardInterrupt: | |
self.logger.warning("Keyboard Interrupt Detected. Shutting down!") | |
self.shutdown_server() | |
def accept_client(self, client_reader: asyncio.StreamReader, client_writer: asyncio.StreamWriter): | |
''' | |
Callback that is used when server accepts clients | |
Parameters | |
---------- | |
client_reader : asyncio.StreamReader | |
StreamReader generated by asyncio.start_server. | |
client_writer : asyncio.StreamWriter | |
StreamWriter generated by asyncio.start_server. | |
---------- | |
''' | |
client = Client(client_reader, client_writer) | |
task = asyncio.Task(self.handle_client(client)) | |
self.clients[task] = client | |
client_ip = client_writer.get_extra_info('peername')[0] | |
client_port = client_writer.get_extra_info('peername')[1] | |
self.logger.info(f"New Connection: {client_ip}:{client_port}") | |
task.add_done_callback(self.disconnect_client) | |
async def handle_client(self, client: Client): | |
''' | |
Handles incoming messages from client | |
Parameters | |
---------- | |
client_reader : asyncio.StreamReader | |
StreamReader generated by asyncio.start_server. | |
client_writer : asyncio.StreamWriter | |
StreamWriter generated by asyncio.start_server. | |
---------- | |
''' | |
while True: | |
client_message = await client.get_message() | |
if client_message.startswith("quit"): | |
break | |
elif client_message.startswith("/"): | |
self.handle_client_command(client, client_message) | |
else: | |
self.broadcast_message( | |
f"{client.nickname}: {client_message}".encode('utf8')) | |
self.logger.info(f"{client_message}") | |
await client.writer.drain() | |
self.logger.info("Client Disconnected!") | |
def handle_client_command(self, client: Client, client_message: str): | |
client_message = client_message.replace("\n", "").replace("\r", "") | |
if client_message.startswith("/nick"): | |
split_client_message = client_message.split(" ") | |
if len(split_client_message) >= 2: | |
client.nickname = split_client_message[1] | |
client.writer.write( | |
f"Nickname changed to {client.nickname}\n".encode('utf8')) | |
return | |
client.writer.write("Invalid Command\n".encode('utf8')) | |
def broadcast_message(self, message: bytes, exclusion_list: list = []): | |
''' | |
Parameters | |
---------- | |
message : bytes | |
A message consisting of utf8 bytes to broadcast to all clients. | |
OPTIONAL exclusion_list : list[Client] | |
A list of clients to exclude from receiving the provided message. | |
---------- | |
''' | |
for client in self.clients.values(): | |
if client not in exclusion_list: | |
client.writer.write(message) | |
def disconnect_client(self, task: asyncio.Task): | |
''' | |
Disconnects and broadcasts to the other clients that a client has been disconnected. | |
Parameters | |
---------- | |
task : asyncio.Task | |
The task object associated with the client generated during self.accept_client() | |
---------- | |
''' | |
client = self.clients[task] | |
self.broadcast_message( | |
f"{client.nickname} has left!".encode('utf8'), [client]) | |
del self.clients[task] | |
client.writer.write('quit'.encode('utf8')) | |
client.writer.close() | |
self.logger.info("End Connection") | |
def shutdown_server(self): | |
''' | |
Shuts down server. | |
''' | |
self.logger.info("Shutting down server!") | |
for client in self.clients.values(): | |
client.writer.write('quit'.encode('utf8')) | |
self.loop.stop() | |
if __name__ == '__main__': | |
if len(sys.argv) < 3: | |
sys.exit(f"Usage: {sys.argv[0]} HOST_IP PORT") | |
loop = asyncio.get_event_loop() | |
server = Server(sys.argv[1], sys.argv[2], loop) | |
server.start_server() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment