Skip to content

Instantly share code, notes, and snippets.

@aont
Last active December 3, 2023 11:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save aont/c5c797880519e6387a0cf55ce72fb196 to your computer and use it in GitHub Desktop.
Save aont/c5c797880519e6387a0cf55ce72fb196 to your computer and use it in GitHub Desktop.
import sys
import os
import datetime
import argparse
import re
import asyncio
import ipaddress
import socket
try: import tomllib as toml
except ModuleNotFoundError: import tomli as toml
import jsonpath_ng.ext
import boto3
import shtab
class EC2Aux:
    """Start/stop EC2 instances by Name tag and keep local SSH files in sync.

    Settings are read from ``~/.myconfig/ec2aux.toml``. The only key used by
    this class is ``[default] ssh_config_path_list``: the ssh config files
    whose section between ``# begin aws ec2`` and ``# end aws ec2`` is
    regenerated from the current EC2 instance list.
    """

    # Marker lines delimiting the generated section of an ssh config file.
    _BEGIN_MARKER_PAT = re.compile(rb"^# begin aws ec2$", flags=re.MULTILINE)
    _END_MARKER_PAT = re.compile(rb"^# end aws ec2$", flags=re.MULTILINE)
    # Login prompt printed on the serial console (e.g. "ip-10-0-0-1 login: ").
    _LOGIN_PROMPT_PAT = re.compile(r"^ip-\d+-\d+-\d+-\d+ login:\s+$")
    # Public DNS suffixes of EC2 hosts; us-east-1 uses the "compute-1" form.
    _EC2_FQDN_SUFFIXES = (".compute.amazonaws.com", ".compute-1.amazonaws.com")

    def __init__(self):
        """Load the configuration and expand the configured ssh config paths."""
        self.config = self._load_config()
        # Environment variables and a leading "~" are both expanded.
        self.ssh_config_path_list = tuple(
            os.path.expanduser(os.path.expandvars(ssh_config_path))
            for ssh_config_path in self.config["default"]["ssh_config_path_list"]
        )

    def _setup_ec2_client(self):
        """Create the boto3 EC2 client and store it as ``self.ec2``."""
        self.ec2 = boto3.client('ec2')

    @staticmethod
    def _load_config():
        """Parse ``~/.myconfig/ec2aux.toml`` and return it as a dict."""
        # Binary mode lets the TOML parser decode UTF-8 itself instead of
        # depending on the locale's default text encoding.
        with open(os.path.expanduser("~/.myconfig/ec2aux.toml"), "rb") as fp:
            return toml.load(fp)

    def _get_describe_instances(self):
        """Return the raw ``describe_instances`` response of ``self.ec2``."""
        return self.ec2.describe_instances()

    @staticmethod
    def _get_owner_id_set(describe_instances):
        """Return the set of ``OwnerId`` values over all reservations."""
        return {
            reservation["OwnerId"]
            for reservation in describe_instances["Reservations"]
        }

    @staticmethod
    def _get_instance_info(describe_instances, instance_name):
        """Return the single instance whose ``Name`` tag equals *instance_name*.

        Raises:
            AttributeError: no instance carries that name.
            Exception: more than one instance carries that name.
        """
        # Plain tag comparison: the name is never spliced into a query
        # language, so any characters in it are safe.
        matches = []
        for instance_info in EC2Aux._get_instance_info_list(describe_instances):
            tag_dict = EC2Aux._get_tag_dict(instance_info)
            # The membership test keeps instance_name=None from matching
            # instances that simply have no Name tag.
            if "Name" in tag_dict and tag_dict["Name"] == instance_name:
                matches.append(instance_info)
        if not matches:
            raise AttributeError(f"instance {instance_name} not found")
        if len(matches) > 1:
            raise Exception(f"multiple instances found with name {instance_name}")
        return matches[0]

    @staticmethod
    def _get_instance_info_list(describe_instances):
        """Return every instance dict of every reservation as a flat list."""
        return [
            instance_info
            for reservation in describe_instances.get("Reservations", ())
            for instance_info in reservation.get("Instances", ())
        ]

    @staticmethod
    def _get_tag_dict(instance_info):
        """Return the instance's tags as ``{Key: Value}``.

        An instance without a ``Tags`` entry yields an empty dict.
        """
        return {tag["Key"]: tag["Value"] for tag in instance_info.get("Tags", ())}

    def _update_ssh_config_one(self, ssh_config_path, ssh_config_ec2):
        """Replace the generated section of one ssh config file.

        Args:
            ssh_config_path: path of the ssh config file; a missing file is
                skipped with a debug message.
            ssh_config_ec2: bytes placed between the two marker lines.

        Raises:
            ValueError: the begin marker is missing, or no end marker follows it.

        When the content changes, the previous file is first moved to
        ``<ssh_config_path>_backup/<mtime as YYYYmmdd_HHMMSS>``.
        """
        if not os.path.isfile(ssh_config_path):
            sys.stderr.write(f"debug: {ssh_config_path} does not exist. skipping...\n")
            return
        with open(ssh_config_path, "rb") as fp:
            ssh_config_bytes = fp.read()

        begin_match = self._BEGIN_MARKER_PAT.search(ssh_config_bytes)
        if begin_match is None:
            raise ValueError(f"'# begin aws ec2' marker not found in {ssh_config_path}")
        # Only an end marker located after the begin marker is acceptable.
        end_match = self._END_MARKER_PAT.search(ssh_config_bytes, begin_match.end())
        if end_match is None:
            raise ValueError(
                f"'# end aws ec2' marker not found after the begin marker in {ssh_config_path}"
            )

        # An end marker found after the begin marker implies the begin marker
        # line is terminated by "\n"; the section starts right after it.
        section_start = begin_match.end() + 1
        ssh_config_bytes_mod = (
            ssh_config_bytes[:section_start]
            + ssh_config_ec2
            + ssh_config_bytes[end_match.start():]
        )
        if ssh_config_bytes_mod == ssh_config_bytes:
            sys.stderr.write(f"debug: skip updating {ssh_config_path} since update is not necessary\n")
            return

        ssh_config_backup_dir = ssh_config_path + "_backup"
        os.makedirs(ssh_config_backup_dir, exist_ok=True)
        mtime = os.path.getmtime(ssh_config_path)
        backup_name = datetime.datetime.fromtimestamp(mtime).strftime("%Y%m%d_%H%M%S")
        # os.replace overwrites an existing backup of the same name on every platform.
        os.replace(ssh_config_path, os.path.join(ssh_config_backup_dir, backup_name))
        with open(ssh_config_path, "wb") as fp:
            fp.write(ssh_config_bytes_mod)

    @staticmethod
    def _remove_old_backups(ssh_config_path):
        """Delete all but the three most recently modified config files.

        Candidates are every entry of ``<ssh_config_path>_backup`` plus every
        regular file in the config's directory whose name starts with
        ``config``. The live config file itself is never deleted.
        """
        ssh_config_backup_dir = ssh_config_path + "_backup"
        if os.path.exists(ssh_config_backup_dir):
            backup_dir_paths = [
                os.path.join(ssh_config_backup_dir, fn)
                for fn in os.listdir(ssh_config_backup_dir)
            ]
        else:
            sys.stderr.write(f"debug: {ssh_config_backup_dir} does not exist. \n")
            backup_dir_paths = []

        ssh_dirpath = os.path.dirname(ssh_config_path)
        sibling_paths = []
        if os.path.exists(ssh_dirpath):
            for fn in os.listdir(ssh_dirpath):
                filepath = os.path.join(ssh_dirpath, fn)
                if fn.startswith("config") and os.path.isfile(filepath):
                    sibling_paths.append(filepath)

        candidates = sorted(sibling_paths + backup_dir_paths, key=os.path.getmtime)
        live_config = os.path.abspath(ssh_config_path)
        for old_path in candidates[:-3]:
            # The live config is itself a candidate (its name starts with
            # "config"); it must survive even when it is not among the newest.
            if os.path.abspath(old_path) == live_config:
                continue
            os.remove(old_path)

    @staticmethod
    def _get_owner_id(describe_instances: dict):
        """Return the single ``OwnerId`` shared by all reservations.

        Raises:
            ValueError: the reservations do not share exactly one owner id.
        """
        owner_id_set = EC2Aux._get_owner_id_set(describe_instances)
        if len(owner_id_set) != 1:
            raise ValueError(f"expected exactly one owner id, found {sorted(owner_id_set)}")
        return owner_id_set.pop()

    def update_ssh_config(self):
        """Regenerate the EC2 section of every configured ssh config file.

        One ``Host`` entry is written per instance that has a ``Name`` tag;
        the ``User`` tag, when present, becomes the ``User`` line and the
        public IP address becomes ``HostName`` (``0.0.0.0`` when the instance
        has none). Afterwards old backups are pruned and EC2 entries are
        removed from the ``known_hosts`` files.
        """
        self._setup_ec2_client()
        sys.stderr.write("info: acquire ec2 info\n")
        describe_instances = self._get_describe_instances()
        instance_info_list = self._get_instance_info_list(describe_instances)

        ssh_config_ec2_list = []
        for instance_info in instance_info_list:
            tag_dict = self._get_tag_dict(instance_info)
            instance_name = tag_dict.get("Name")
            if instance_name is None:
                sys.stderr.write("warn: Name is not set. skipping...\n")
                continue
            sys.stderr.write(f"debug: {instance_name=}\n")

            user_name = tag_dict.get("User")
            if user_name is None:
                sys.stderr.write("warn: User is not set.\n")
            else:
                sys.stderr.write(f"debug: {user_name=}\n")

            instance_id = instance_info["InstanceId"]
            sys.stderr.write(f"debug: {instance_id=}\n")

            instance_address = instance_info.get("PublicIpAddress")
            if instance_address is None:
                # Keep the Host entry, with a placeholder address.
                instance_address = "0.0.0.0"
                sys.stderr.write("warn: address is not assigned.\n")
            else:
                sys.stderr.write(f"debug: {instance_address=}\n")

            sys.stderr.write("info: update ssh config\n")
            entry_lines = ["Host " + instance_name + "\n"]
            if user_name is not None:
                entry_lines.append("User " + user_name + "\n")
            entry_lines.append("HostName " + instance_address + "\n")
            ssh_config_ec2_list.append("".join(entry_lines))

        ssh_config_ec2 = "\n\n".join(ssh_config_ec2_list).encode()
        for ssh_config_path in self.ssh_config_path_list:
            sys.stderr.write(f"debug: {ssh_config_path=}\n")
            self._update_ssh_config_one(ssh_config_path, ssh_config_ec2)

        sys.stderr.write("info: remove old backups\n")
        for ssh_config_path in self.ssh_config_path_list:
            sys.stderr.write(f"debug: {ssh_config_path=}\n")
            self._remove_old_backups(ssh_config_path)

        self.remove_ec2_hosts_from_known_hosts()

    def start_or_stop_instance(self, instance_name=None, start_or_stop=True, wait_complete=False):
        """Start or stop the instance whose ``Name`` tag is *instance_name*.

        Args:
            instance_name: value of the instance's ``Name`` tag.
            start_or_stop: ``True`` to start the instance, ``False`` to stop it.
            wait_complete: when true, block afterwards until a login prompt
                appears on the instance's serial console.

        Raises:
            AttributeError: no instance has that name.
            Exception: several instances have that name, or the API response
                does not report exactly one instance in an expected state
                (the response is the exception argument).
        """
        self._setup_ec2_client()
        sys.stderr.write("info: acquire ec2 info\n")
        describe_instances = self._get_describe_instances()
        instance_info = self._get_instance_info(describe_instances, instance_name)
        instance_id = instance_info["InstanceId"]

        if start_or_stop:
            action = "start"
            request = self.ec2.start_instances
            result_key = "StartingInstances"
            expected_codes = (
                0,   # pending
                16,  # running
            )
        else:
            action = "stop"
            request = self.ec2.stop_instances
            result_key = "StoppingInstances"
            expected_codes = (
                64,  # stopping
                80,  # stopped
            )

        sys.stderr.write(f"info: {action} instance\n")
        awscli_obj = request(InstanceIds=[instance_id])
        state_changes = awscli_obj[result_key]
        # Explicit checks instead of assert so they survive "python -O".
        if len(state_changes) != 1 or state_changes[0]["CurrentState"]["Code"] not in expected_codes:
            raise Exception(awscli_obj)

        if wait_complete:
            self._wait_boot_complete(instance_id)

    @staticmethod
    async def _serial_ping_stdin(serial_stdin: asyncio.StreamWriter, stop_ping_event: asyncio.Event):
        """Write a newline to *serial_stdin* about once per second.

        Returns once *stop_ping_event* is set; at least one newline is always
        sent first.
        """
        while True:
            sys.stderr.write("info: ping\n")
            serial_stdin.write(b"\n")
            await serial_stdin.drain()
            try:
                # Doubles as the one-second pause between pings.
                await asyncio.wait_for(stop_ping_event.wait(), timeout=1)
            except asyncio.TimeoutError:
                continue
            break

    @staticmethod
    async def _serial_read_stdout(serial_stdout: asyncio.StreamReader):
        """Read lines from *serial_stdout* until a login prompt line is seen.

        Raises:
            Exception: the stream reaches EOF before a login prompt appears.
        """
        while True:
            raw_line: bytes = await serial_stdout.readline()
            if raw_line == b'':
                sys.stderr.write("debug: reading ssh stdout finished unexpectedly\n")
                raise Exception("ssh stdout closed before a login prompt appeared")
            line = raw_line.rstrip(b"\r\n").decode(errors="ignore")
            if not line:
                continue
            sys.stderr.write(f"debug: ssh: {line=}\n")
            if EC2Aux._LOGIN_PROMPT_PAT.match(line):
                break

    def _setup_ec2_ic(self):
        """Create the boto3 EC2 Instance Connect client as ``self.ec2icc``."""
        self.ec2icc = boto3.client('ec2-instance-connect')

    async def _wait_boot_complete_async(self, instance_id):
        """Wait until *instance_id* shows a login prompt on its serial console.

        Pushes ``~/.ssh/id_rsa.pub`` for serial console access (retrying every
        second until accepted), then runs ``ssh`` against the serial console
        endpoint and sends newlines until the login prompt is printed.
        """
        self._setup_ec2_ic()
        ssh_pubkey_path = os.path.expanduser("~/.ssh/id_rsa.pub")
        with open(ssh_pubkey_path, "rt") as f:
            ssh_pubkey = f.read()

        e_prev = None
        while True:
            try:
                response = self.ec2icc.send_serial_console_ssh_public_key(
                    InstanceId=instance_id,
                    SerialPort=0,
                    SSHPublicKey=ssh_pubkey,
                )
            except Exception as e:
                # Deliberate best effort: keep retrying, but report only when
                # the kind of failure changes.
                if not isinstance(e, type(e_prev)):
                    sys.stderr.write(f"warn: exception occured. retrying... {e=}\n")
                e_prev = e
            else:
                if response.get("Success"):
                    break
                sys.stderr.write(f"info: {response=}\n")
            await asyncio.sleep(1)

        # Connect to the serial console endpoint of the region the key was
        # pushed to, rather than a fixed region.
        region_name = self.ec2icc.meta.region_name
        ssh_proc = await asyncio.create_subprocess_exec(
            "ssh", "-tt", "-l", f"{instance_id}.port0",
            f"serial-console.ec2-instance-connect.{region_name}.aws",
            stdout=asyncio.subprocess.PIPE,
            stdin=asyncio.subprocess.PIPE,
        )
        stop_ping_event = asyncio.Event()
        ping_stdin_task = asyncio.create_task(EC2Aux._serial_ping_stdin(ssh_proc.stdin, stop_ping_event))
        read_stdout_task = asyncio.create_task(EC2Aux._serial_read_stdout(ssh_proc.stdout))
        try:
            await read_stdout_task
            stop_ping_event.set()
            await ping_stdin_task
            ssh_proc.stdin.write(b"~.")
        finally:
            # Always release the ping task and the ssh process, including
            # when reading the console failed.
            ping_stdin_task.cancel()
            if ssh_proc.returncode is None:
                ssh_proc.kill()
            ret_ssh = await ssh_proc.wait()
            sys.stderr.write(f"debug: {ret_ssh=}\n")

    def _wait_boot_complete(self, instance_id):
        """Synchronous wrapper around :meth:`_wait_boot_complete_async`."""
        return asyncio.run(self._wait_boot_complete_async(instance_id))

    @staticmethod
    def _known_host_fqdn(address):
        """Return the host name used to classify a ``known_hosts`` entry.

        An IP address is reverse-resolved; anything else, and any IP address
        that cannot be resolved, is returned unchanged.
        """
        try:
            ipaddress.ip_address(address)
        except ValueError:
            return address
        try:
            return socket.gethostbyaddr(address)[0]
        except OSError:
            # No reverse DNS record: treat the entry as a non-EC2 host.
            return address

    def remove_ec2_hosts_from_known_hosts(self):
        """Drop EC2 entries from the ``known_hosts`` next to each ssh config.

        A line is dropped when its first space-separated field is (or
        reverse-resolves to) a name starting with ``ec2-`` and ending with an
        EC2 public DNS suffix. A changed file is rewritten and its previous
        version kept as ``known_hosts.old`` in the same directory.
        """
        for ssh_config_path in self.ssh_config_path_list:
            ssh_config_dirpath = os.path.dirname(ssh_config_path)
            known_hosts_path = os.path.join(ssh_config_dirpath, "known_hosts")
            sys.stderr.write(f"info: remove ec2 hosts from {known_hosts_path}\n")
            if not os.path.isfile(known_hosts_path):
                sys.stderr.write(f"debug: {known_hosts_path} does not exist. skipping...\n")
                continue

            known_hosts_mod_lines = []
            has_change = False
            with open(known_hosts_path, "rt") as fp:
                for known_hosts_line in fp:
                    address = known_hosts_line.split(" ")[0]
                    fqdn_address = self._known_host_fqdn(address)
                    if fqdn_address.startswith("ec2-") and fqdn_address.endswith(self._EC2_FQDN_SUFFIXES):
                        sys.stderr.write(f"info: {fqdn_address} is ec2 address. skipping...\n")
                        has_change = True
                        continue
                    known_hosts_mod_lines.append(known_hosts_line)

            if not has_change:
                sys.stderr.write("info: no ec2 address is found.\n")
                # Move on to the next file instead of abandoning the rest.
                continue

            known_hosts_backup_path = os.path.join(ssh_config_dirpath, "known_hosts.old")
            # os.replace overwrites any previous known_hosts.old.
            os.replace(known_hosts_path, known_hosts_backup_path)
            with open(known_hosts_path, "wt") as fp:
                fp.writelines(known_hosts_mod_lines)
def main():
    """Parse the command line and run the selected sub-command.

    Sub-commands:
        start-instance NAME [--wait-complete]
        stop-instance NAME
        update-ssh-config

    Returns:
        The return value of the invoked ``EC2Aux`` method, or ``2`` after
        printing the usage when no sub-command was given.
    """
    parser = argparse.ArgumentParser(add_help=True)
    shtab.add_argument_to(parser, ["-s", "--print-completion"])
    subparsers = parser.add_subparsers(title="command", dest="command")

    parser_start_instance = subparsers.add_parser("start-instance", help="start instance")
    parser_start_instance.add_argument("instance_name")
    parser_start_instance.add_argument("--wait-complete", dest="wait_complete", action="store_true", help="wait boot completion")

    parser_stop_instance = subparsers.add_parser("stop-instance", help="stop instance")
    parser_stop_instance.add_argument("instance_name")

    subparsers.add_parser("update-ssh-config", help="update ssh config")

    args = parser.parse_args()
    if args.command == "start-instance":
        return EC2Aux().start_or_stop_instance(args.instance_name, True, wait_complete=args.wait_complete)
    if args.command == "stop-instance":
        return EC2Aux().start_or_stop_instance(args.instance_name, False)
    if args.command == "update-ssh-config":
        return EC2Aux().update_ssh_config()

    # No sub-command given: show the usage instead of silently doing nothing.
    parser.print_help(sys.stderr)
    return 2
# Script entry point: the value returned by main() becomes the exit status.
if __name__ == '__main__':
    raise SystemExit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment