Created
March 31, 2021 02:12
-
-
Save jplock/e783458f2163ab4fa5a75bdc10108428 to your computer and use it in GitHub Desktop.
AWS Batch Multi-node Parallel Job sample code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""AWS Batch multi-node parallel job check-in sample.

The main node runs a small TCP server and waits until every node in the job
has checked in; every other node connects to the main node and sends a
``CHECKIN=<node index>`` message.
"""
import os
import socket
import socketserver
import sys
import threading
import time

# TCP port the main node listens on and the other nodes connect to.
PORT = 8080
# Total number of nodes in the job, as reported by AWS Batch. Read once at
# import time, so importing this module outside a Batch job raises KeyError.
NUM_NODES = int(os.environ["AWS_BATCH_JOB_NUM_NODES"])
# Node index -> check-in status ("OK"); populated on the main node.
NODES = dict()
class RequestHandler(socketserver.BaseRequestHandler):
    """Record a worker node's ``CHECKIN=<node index>`` message in ``NODES``.

    A well-formed check-in is stored under the node's integer index and
    answered with ``OK``. A malformed message is answered with ``ERROR`` and
    not recorded, instead of raising inside the handler thread.
    """

    def handle(self):
        raw = self.request.recv(1024)
        if not raw:
            # Peer closed the connection without sending anything.
            return
        node_index = self._parse_checkin(raw)
        if node_index is None:
            print(f"Rejected malformed message: {raw!r}")
            self.request.sendall(b"ERROR")
            return
        print(f"Node {node_index} checked in")
        response = "OK"
        # int key: the same type the main node uses for its own entry, so
        # "1", " 1" and "1\n" all count as the same node.
        NODES[node_index] = response
        self.request.sendall(response.encode())

    @staticmethod
    def _parse_checkin(raw):
        """Return the node index from a ``CHECKIN=<int>`` payload, or None."""
        try:
            data = raw.decode("ascii")
        except UnicodeDecodeError:
            return None
        # partition() splits on the first "=" only; split("=", 2) could
        # yield three parts and break a two-name unpacking.
        command, sep, value = data.strip().partition("=")
        if command != "CHECKIN" or not sep:
            return None
        try:
            return int(value)
        except ValueError:
            return None
class ThreadedTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
    """TCP server that handles each incoming connection in its own thread."""
def is_main_node(node_index):
    """Return True when *node_index* equals the job's main node index.

    The main node index is read from ``AWS_BATCH_JOB_MAIN_NODE_INDEX`` on
    every call.
    """
    main_index = os.environ["AWS_BATCH_JOB_MAIN_NODE_INDEX"]
    return node_index == int(main_index)
def printenv():
    """Dump the process environment to stdout, one ``NAME=value`` per line.

    NOTE(review): this prints every variable verbatim, including any secrets
    that happen to be present in the environment.
    """
    print("Printing environment variables:")
    lines = (f"{name}={value}" for name, value in os.environ.items())
    for line in lines:
        print(line)
def start_server(node_index, poll_interval=10, timeout=None):
    """Run the check-in server on the main node until every node reports in.

    The main node records itself first, then serves check-in requests in a
    background thread while the calling thread polls ``NODES`` every
    *poll_interval* seconds. The server is shut down on every exit path.

    Args:
        node_index: Index of the main node; recorded as already checked in.
        poll_interval: Seconds to sleep between progress checks.
        timeout: Optional overall limit in seconds. ``None`` (the default)
            waits indefinitely. The limit is checked once per poll, so it
            may be exceeded by up to one *poll_interval*.

    Raises:
        TimeoutError: If *timeout* is set and not all nodes checked in in time.
    """
    # Check in the main node
    NODES[node_index] = "OK"
    deadline = None if timeout is None else time.monotonic() + timeout
    server = ThreadedTCPServer(("0.0.0.0", PORT), RequestHandler)
    with server:
        server_thread = threading.Thread(target=server.serve_forever)
        server_thread.daemon = True
        server_thread.start()
        print(f"Server loop running in thread: {server_thread.name}", flush=True)
        try:
            while True:
                node_count = len(NODES)
                print(f"{node_count}/{NUM_NODES} checked in", flush=True)
                # >= rather than ==: if NODES ever holds more entries than
                # expected (e.g. one node under both an int and a str key),
                # an exact match would never be reached and this would spin.
                if node_count >= NUM_NODES:
                    print(
                        f"All {node_count} nodes have checked in, shutting down",
                        flush=True,
                    )
                    return
                if deadline is not None and time.monotonic() >= deadline:
                    raise TimeoutError(
                        f"Only {node_count}/{NUM_NODES} nodes checked in "
                        f"within {timeout} seconds"
                    )
                print(f"Sleeping for {poll_interval} seconds...", flush=True)
                time.sleep(poll_interval)
        finally:
            # Stop serve_forever() on success, timeout or interrupt before
            # the with-block closes the listening socket.
            server.shutdown()
def start_client(node_index, attempts=30, retry_delay=2.0):
    """Check this node in with the main node's check-in server.

    Connects to ``AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS`` on ``PORT``,
    sends ``CHECKIN=<node_index>`` and prints the server's reply.

    Args:
        node_index: This node's index within the job.
        attempts: Maximum number of connection attempts (at least 1).
        retry_delay: Seconds to wait between failed connection attempts.

    Raises:
        ValueError: If *attempts* is less than 1.
        OSError: The error from the final attempt if every connect fails.
    """
    if attempts < 1:
        raise ValueError(f"attempts must be >= 1, got {attempts}")
    main_ip = os.environ["AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS"]
    for attempt in range(1, attempts + 1):
        # After a failed connect() the socket state is unspecified, so each
        # attempt uses a fresh socket.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.connect((main_ip, PORT))
            except OSError as err:
                # If this node starts before the main node's server is
                # listening, the connect is refused; retry rather than
                # failing the node outright.
                if attempt == attempts:
                    raise
                print(
                    f"Connect to {main_ip}:{PORT} failed ({err}), "
                    f"attempt {attempt}/{attempts}; retrying in "
                    f"{retry_delay} seconds",
                    flush=True,
                )
            else:
                sock.sendall(bytes(f"CHECKIN={node_index}", "ascii"))
                response = str(sock.recv(1024), "ascii")
                print(f"Received: {response}")
                return
        time.sleep(retry_delay)
def main():
    """Act as the check-in server on the main node, or check in elsewhere.

    Reads this node's index from ``AWS_BATCH_JOB_NODE_INDEX`` and dumps the
    environment. It then either runs the check-in server (main node) or sends
    a check-in to it (every other node). Finally it exits with status 0.
    """
    node_index = int(os.environ["AWS_BATCH_JOB_NODE_INDEX"])
    printenv()
    if not is_main_node(node_index):
        print(f"I am node {node_index}")
        start_client(node_index)
        print(f"Node {node_index} finished")
    else:
        print("I am the main node")
        start_server(node_index)
        print("Main node finished")
    sys.exit(0)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment