-
-
Save D3ISM3/48a5b703ee902f0b15202fcab38862ce to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import sys | |
import logging | |
import pathlib | |
import os | |
from ClientModel import Client | |
from datetime import datetime | |
class Server:
    '''
    asyncio chat server: relays each client's messages to all connected
    clients and supports a "/nick <name>" command.
    '''

    def __init__(self, ip: str, port: int, loop: asyncio.AbstractEventLoop):
        '''
        Parameters
        ----------
        ip : str
            IP that the server will be using
        port : int
            Port that the server will be using
        loop : asyncio.AbstractEventLoop
            Event loop the server runs on. start_server() drives it with
            run_until_complete() / run_forever(), so it must not already be
            running when start_server() is called.
        ----------
        '''
        self.__ip: str = ip
        self.__port: int = port
        self.__loop: asyncio.AbstractEventLoop = loop
        self.__logger: logging.Logger = self.initialize_logger()
        # Maps each per-client handler task to the Client object it serves.
        self.__clients: "dict[asyncio.Task, Client]" = {}
        # Listening asyncio server; assigned by start_server() and closed by
        # shutdown_server(). None until the server has started listening.
        self.server = None
        self.logger.info(f"Server Initialized with {self.ip}:{self.port}")

    @property
    def ip(self):
        return self.__ip

    @property
    def port(self):
        return self.__port

    @property
    def loop(self):
        return self.__loop

    @property
    def logger(self):
        return self.__logger

    @property
    def clients(self):
        return self.__clients

    def initialize_logger(self):
        '''
        Initializes a logger and generates a log file in ./logs.

        The 'Server' logger is a process-wide singleton (logging.getLogger
        returns the same object for the same name). Handlers are therefore
        attached only once, so creating a second Server does not duplicate
        every log line.

        Returns
        -------
        logging.Logger
            Used for writing logs of varying levels to the console (INFO and
            above) and to the log file (DEBUG and above).
        -------
        '''
        logger = logging.getLogger('Server')
        logger.setLevel(logging.DEBUG)
        if logger.handlers:
            return logger

        log_dir = pathlib.Path.cwd() / "logs"
        log_dir.mkdir(parents=True, exist_ok=True)
        log_name = f'{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}_server.log'

        formatter = logging.Formatter(
            '[%(asctime)s] - %(levelname)s - %(message)s'
        )

        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        ch.setFormatter(formatter)

        # Build the file path from the same directory that was just created,
        # instead of a separately hard-coded relative string.
        fh = logging.FileHandler(filename=log_dir / log_name, encoding='utf-8')
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)

        logger.addHandler(ch)
        logger.addHandler(fh)
        return logger

    def start_server(self):
        '''
        Starts the server on IP and PORT and runs the event loop until it is
        stopped, interrupted with Ctrl-C, or an error occurs. The server is
        shut down in every one of those cases.
        '''
        try:
            # Keep the asyncio.Server returned by start_server (not the
            # coroutine) so shutdown_server() can close the listening socket.
            self.server = self.loop.run_until_complete(
                asyncio.start_server(self.accept_client, self.ip, self.port)
            )
            self.loop.run_forever()
        except KeyboardInterrupt:
            self.logger.warning("Keyboard Interrupt Detected. Shutting down!")
        except Exception as e:
            # logger.exception also records the traceback that
            # logger.error(e) discarded.
            self.logger.exception(e)
        self.shutdown_server()

    def accept_client(self, client_reader: asyncio.StreamReader, client_writer: asyncio.StreamWriter):
        '''
        Callback that is used when server accepts clients.

        Wraps the connection in a Client, spawns a task running
        incoming_client_message_cb for it, and registers disconnect_client
        to run when that task finishes.

        Parameters
        ----------
        client_reader : asyncio.StreamReader
            StreamReader generated by asyncio.start_server.
        client_writer : asyncio.StreamWriter
            StreamWriter generated by asyncio.start_server.
        ----------
        '''
        client = Client(client_reader, client_writer)
        task = self.loop.create_task(self.incoming_client_message_cb(client))
        self.clients[task] = client
        peername = client_writer.get_extra_info('peername')
        client_ip, client_port = peername[0], peername[1]
        self.logger.info(f"New Connection: {client_ip}:{client_port}")
        task.add_done_callback(self.disconnect_client)

    async def incoming_client_message_cb(self, client: "Client"):
        '''
        Callback for handling incoming messages from client.

        Messages starting with "quit" end the session. Messages starting
        with "/" are treated as commands. Anything else is broadcast to
        every client as "<nickname>: <message>".

        Parameters
        ----------
        client : Client
        ----------
        '''
        while True:
            client_message = await client.get_message()
            if not client_message:
                # NOTE(review): assumes Client.get_message() returns an empty
                # value ("" / None) once the peer closes the stream, as
                # StreamReader.read() does with b"". Without this guard the
                # loop would spin broadcasting empty messages. Confirm
                # against ClientModel.
                break
            if client_message.startswith("quit"):
                break
            if client_message.startswith("/"):
                self.handle_client_command(client, client_message)
            else:
                self.broadcast_message(
                    f"{client.nickname}: {client_message}".encode('utf8'))
                self.logger.info(f"{client_message}")
            await client.writer.drain()
        self.logger.info("Client Disconnected!")

    def handle_client_command(self, client: "Client", client_message: str):
        '''
        Parses a "/"-prefixed message and executes the matching command.

        Supported: "/nick <name>" sets client.nickname to <name>. Any other
        input, or "/nick" without a name, replies "Invalid Command".

        Parameters
        ----------
        client : Client
            Client associated with incoming message
        client_message : str
            Incoming message from client that will be parsed for any valid commands
        '''
        client_message = client_message.replace("\n", "").replace("\r", "")
        # split() with no argument collapses runs of whitespace, so
        # "/nick  bob" yields "bob" rather than an empty nickname. Comparing
        # the whole first token rejects "/nickfoo".
        parts = client_message.split()
        if len(parts) >= 2 and parts[0] == "/nick":
            client.nickname = parts[1]
            client.writer.write(
                f"Nickname changed to {client.nickname}\n".encode('utf8'))
            return
        client.writer.write("Invalid Command\n".encode('utf8'))

    def broadcast_message(self, message: bytes, exclusion_list: "list | None" = None):
        '''
        Parameters
        ----------
        message : bytes
            A message consisting of utf8 bytes to broadcast to all clients.
        OPTIONAL exclusion_list : list[Client]
            A list of clients to exclude from receiving the provided message.
            Defaults to no exclusions.
        ----------
        '''
        # None instead of a shared mutable [] default.
        excluded = exclusion_list if exclusion_list is not None else []
        for client in self.clients.values():
            if client not in excluded:
                client.writer.write(message)

    def disconnect_client(self, task: asyncio.Task):
        '''
        Disconnects and broadcasts to the other clients that a client has been disconnected.

        If the client's handler task failed, its exception is logged here.

        Parameters
        ----------
        task : asyncio.Task
            The task object associated with the client generated during self.accept_client()
        ----------
        '''
        if not task.cancelled() and task.exception() is not None:
            # Retrieving the exception here logs it with context and keeps
            # asyncio from reporting "Task exception was never retrieved".
            self.logger.error("Client handler failed", exc_info=task.exception())
        client = self.clients.pop(task)
        self.broadcast_message(
            f"{client.nickname} has left!".encode('utf8'), [client])
        client.writer.write('quit'.encode('utf8'))
        client.writer.close()
        self.logger.info("End Connection")

    def shutdown_server(self):
        '''
        Shuts down server: stops accepting connections, sends 'quit' to
        every connected client, and stops the event loop.
        '''
        self.logger.info("Shutting down server!")
        if self.server is not None:
            self.server.close()
        for client in self.clients.values():
            client.writer.write('quit'.encode('utf8'))
        self.loop.stop()
if __name__ == '__main__':
    # Entry point: python <script> HOST_IP PORT
    if len(sys.argv) < 3:
        sys.exit(f"Usage: {sys.argv[0]} HOST_IP PORT")
    # Server documents port as int; argv values are always str.
    try:
        port = int(sys.argv[2])
    except ValueError:
        sys.exit(f"Invalid PORT: {sys.argv[2]!r}")
    # asyncio.get_event_loop() without a running loop is deprecated, so
    # create the loop explicitly and install it as the current loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server = Server(sys.argv[1], port, loop)
    try:
        server.start_server()
    finally:
        # Release the loop's resources once the server has stopped.
        loop.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.