Skip to content

Instantly share code, notes, and snippets.

@jplock
Created March 31, 2021 02:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jplock/e783458f2163ab4fa5a75bdc10108428 to your computer and use it in GitHub Desktop.
Save jplock/e783458f2163ab4fa5a75bdc10108428 to your computer and use it in GitHub Desktop.
AWS Batch Multi-node Parallel Job sample code
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import socket
import socketserver
import sys
import time
import threading
# TCP port the main node's server binds to and the other nodes connect to.
PORT = 8080
# Number of nodes that must check in, read from the AWS_BATCH_JOB_NUM_NODES
# environment variable (raises KeyError at import time if it is not set).
NUM_NODES = int(os.environ["AWS_BATCH_JOB_NUM_NODES"])
# Check-in registry: node index -> status string ("OK"). The main node polls
# its size and compares it with NUM_NODES.
NODES = {}
class RequestHandler(socketserver.BaseRequestHandler):
    """Handle one check-in connection from a node.

    The expected payload is the ASCII text ``CHECKIN=<node index>``. A valid
    check-in is recorded in the module-level ``NODES`` registry and answered
    with ``OK``; any other payload is answered with ``ERROR`` instead of
    raising inside the handler thread.
    """

    def handle(self):
        """Read one request, register the node if it is valid, and reply."""
        # errors="replace" keeps stray non-ASCII bytes from raising here; such
        # a payload simply fails the validation below.
        data = self.request.recv(1024).decode("ascii", errors="replace")

        # partition() always yields exactly three parts, so a payload with no
        # "=" or with several "=" cannot break the unpacking.
        command, separator, raw_index = data.partition("=")

        response = "ERROR"
        if command == "CHECKIN" and separator:
            try:
                # Store the index as an int so the key has the same type as
                # the int key start_server() uses for the main node.
                node_index = int(raw_index)
            except ValueError:
                print(
                    f"Ignoring check-in with invalid node index: {raw_index!r}",
                    flush=True,
                )
            else:
                print(f"Node {node_index} checked in")
                response = "OK"
                NODES[node_index] = response
        else:
            print(f"Ignoring unrecognised request: {data!r}", flush=True)

        self.request.sendall(response.encode())
class ThreadedTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
    """TCP server that handles each incoming connection in its own thread.

    All behaviour comes from the two base classes; nothing is overridden.
    """
def is_main_node(node_index):
    """Return True when *node_index* equals the job's main node index.

    The main node index is read from the ``AWS_BATCH_JOB_MAIN_NODE_INDEX``
    environment variable on every call and converted to ``int``.
    """
    main_index = int(os.environ["AWS_BATCH_JOB_MAIN_NODE_INDEX"])
    return main_index == node_index
def printenv():
    """Print a header line followed by every environment variable as NAME=value."""
    print("Printing environment variables:")
    for name in os.environ:
        value = os.environ[name]
        print(f"{name}={value}")
def start_server(node_index):
    """Run the check-in server on the main node until every node has reported.

    Registers the main node itself in ``NODES``, serves ``RequestHandler`` on
    ``PORT`` from a background daemon thread, and polls the registry every 10
    seconds. Once at least ``NUM_NODES`` entries are registered the server is
    shut down and the function returns.

    Args:
        node_index: Index of the main node, recorded as already checked in.
    """
    # Check in the main node
    NODES[node_index] = "OK"

    # Leaving the ``with`` block closes the listening socket.
    with ThreadedTCPServer(("0.0.0.0", PORT), RequestHandler) as server:
        # Daemon thread: it must not keep the process alive if the main
        # thread exits unexpectedly.
        server_thread = threading.Thread(target=server.serve_forever, daemon=True)
        server_thread.start()
        print(f"Server loop running in thread: {server_thread.name}", flush=True)

        while True:
            node_count = len(NODES)
            print(f"{node_count}/{NUM_NODES} checked in", flush=True)
            # ">=" rather than "==": if the registry ever overshoots between
            # two polls (e.g. an unexpected extra key), an equality test would
            # never become true and the main node would wait forever.
            if node_count >= NUM_NODES:
                print(
                    f"All {node_count} nodes have checked in, shutting down",
                    flush=True,
                )
                server.shutdown()
                return
            print("Sleeping for 10 seconds...", flush=True)
            time.sleep(10)
def start_client(node_index):
    """Check this node in with the main node and print the reply.

    Connects to the address in ``AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS``
    on ``PORT``, sends ``CHECKIN=<node_index>`` as ASCII, and prints the
    ASCII-decoded response (at most 1024 bytes are read).

    Args:
        node_index: Index of this node, sent to the main node.
    """
    server_address = (os.environ["AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS"], PORT)
    message = f"CHECKIN={node_index}".encode("ascii")

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as conn:
        conn.connect(server_address)
        conn.sendall(message)
        reply = conn.recv(1024).decode("ascii")

    print(f"Received: {reply}")
def main():
    """Entry point: dump the environment, then act as main node or client.

    The node's own index comes from ``AWS_BATCH_JOB_NODE_INDEX``. The main
    node runs the check-in server; every other node checks in with it. The
    process always exits with status 0 afterwards.
    """
    node_index = int(os.environ["AWS_BATCH_JOB_NODE_INDEX"])

    printenv()

    if not is_main_node(node_index):
        print(f"I am node {node_index}")
        start_client(node_index)
        print(f"Node {node_index} finished")
    else:
        print("I am the main node")
        start_server(node_index)
        print("Main node finished")

    sys.exit(0)


# Run the check-in flow only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment