Last active
December 3, 2023 11:55
-
-
Save aont/c5c797880519e6387a0cf55ce72fb196 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import os | |
import datetime | |
import argparse | |
import re | |
import asyncio | |
import ipaddress | |
import socket | |
try: import tomllib as toml | |
except ModuleNotFoundError: import tomli as toml | |
import jsonpath_ng.ext | |
import boto3 | |
import shtab | |
class EC2Aux:
    """Manage EC2 instances and keep local SSH configuration in sync with them.

    Settings come from ``~/.myconfig/ec2aux.toml``. Its
    ``[default] ssh_config_path_list`` entry names the SSH config files whose
    section between the lines ``# begin aws ec2`` and ``# end aws ec2`` is
    regenerated from the EC2 inventory by :meth:`update_ssh_config`.
    """

    # Region used for the serial console endpoint when the boto3 client does
    # not report one (this value used to be hard-coded).
    _DEFAULT_SERIAL_CONSOLE_REGION = "ap-northeast-1"
    # Suffixes of EC2 public DNS names: "<region>.compute.amazonaws.com" in
    # most regions, "compute-1.amazonaws.com" in us-east-1.
    _EC2_HOSTNAME_SUFFIXES = (".compute.amazonaws.com", ".compute-1.amazonaws.com")

    def __init__(self):
        self.config = self._load_config()
        # Expand both "$VARS" and "~" so entries such as "~/.ssh/config" work.
        self.ssh_config_path_list = tuple(
            os.path.expanduser(os.path.expandvars(ssh_config_path))
            for ssh_config_path in self.config["default"]["ssh_config_path_list"]
        )

    def _setup_ec2_client(self):
        """Create the boto3 EC2 client and store it as ``self.ec2``."""
        self.ec2 = boto3.client('ec2')

    @staticmethod
    def _load_config():
        """Read ``~/.myconfig/ec2aux.toml`` and return the parsed dict."""
        # TOML documents are UTF-8 by specification; don't depend on the locale.
        with open(os.path.expanduser("~/.myconfig/ec2aux.toml"), "rt", encoding="utf-8") as fp:
            return toml.loads(fp.read())

    def _get_describe_instances(self):
        """Return the raw ``describe_instances`` response."""
        return self.ec2.describe_instances()

    @staticmethod
    def _get_owner_id_set(describe_instances):
        """Return the set of ``OwnerId`` values over all reservations."""
        return {reservation["OwnerId"] for reservation in describe_instances["Reservations"]}

    @staticmethod
    def _is_terminated(instance_info):
        """Return True if the instance's ``State.Name`` is ``terminated``."""
        return instance_info.get("State", {}).get("Name") == "terminated"

    @staticmethod
    def _get_instance_info(describe_instances, instance_name):
        """Return the single non-terminated instance whose ``Name`` tag equals *instance_name*.

        Raises:
            AttributeError: no matching instance exists.
            Exception: more than one matching instance exists.
        """
        # Plain Python filtering instead of a jsonpath query built by string
        # interpolation, which broke on names containing spaces or operators.
        # describe_instances may still list recently terminated instances; they
        # cannot be started/stopped and must not count as duplicates.
        instance_info_list = [
            instance_info
            for instance_info in EC2Aux._get_instance_info_list(describe_instances)
            if not EC2Aux._is_terminated(instance_info)
            and EC2Aux._get_tag_dict(instance_info).get("Name") == instance_name
        ]
        if not instance_info_list:
            raise AttributeError(f"instance {instance_name} not found")
        if len(instance_info_list) > 1:
            raise Exception(f"multiple instances found with name {instance_name}")
        return instance_info_list[0]

    @staticmethod
    def _get_instance_info_list(describe_instances):
        """Flatten ``Reservations[*].Instances[*]`` into a single list."""
        return [
            instance_info
            for reservation in describe_instances.get("Reservations", [])
            for instance_info in reservation.get("Instances", [])
        ]

    @staticmethod
    def _get_tag_dict(instance_info):
        """Return the instance's tags as a ``{Key: Value}`` dict."""
        # "Tags" may be absent for untagged instances; treat that as no tags.
        return {tag["Key"]: tag["Value"] for tag in instance_info.get("Tags", [])}

    def _update_ssh_config_one(self, ssh_config_path, ssh_config_ec2):
        """Replace the managed EC2 section of one SSH config file.

        Everything after the line ``# begin aws ec2`` (and its newline) up to
        the next line ``# end aws ec2`` is replaced by *ssh_config_ec2*. If the
        content changes, the current file is first moved to
        ``<ssh_config_path>_backup/<mtime as %Y%m%d_%H%M%S>``.

        Args:
            ssh_config_path: SSH config file; skipped when it does not exist.
            ssh_config_ec2: replacement bytes for the managed section.

        Raises:
            ValueError: the markers are missing or the end marker does not
                follow the begin marker.
        """
        if not os.path.isfile(ssh_config_path):
            sys.stderr.write(f"debug: {ssh_config_path} does not exist. skipping...\n")
            return
        with open(ssh_config_path, "rb") as fp:
            ssh_config_bytes = fp.read()
        begin_aws_ec2_pat = re.compile(b"^# begin aws ec2$", flags=re.MULTILINE)
        end_aws_ec2_pat = re.compile(b"^# end aws ec2$", flags=re.MULTILINE)
        begin_aws_ec2_match = begin_aws_ec2_pat.search(ssh_config_bytes)
        if begin_aws_ec2_match is None:
            raise ValueError(f"{ssh_config_path}: '# begin aws ec2' line not found")
        section_start = begin_aws_ec2_match.end() + 1
        # Slicing (not indexing) so a marker at EOF yields b"" instead of IndexError.
        if ssh_config_bytes[begin_aws_ec2_match.end():section_start] != b"\n":
            raise ValueError(f"{ssh_config_path}: '# begin aws ec2' line is not newline-terminated")
        # Search for the end marker only after the begin marker; an earlier one
        # would otherwise make the slices below overlap and duplicate content.
        end_aws_ec2_match = end_aws_ec2_pat.search(ssh_config_bytes, section_start)
        if end_aws_ec2_match is None:
            raise ValueError(f"{ssh_config_path}: '# end aws ec2' line not found after '# begin aws ec2'")
        ssh_config_bytes_mod = (
            ssh_config_bytes[:section_start] + ssh_config_ec2 + ssh_config_bytes[end_aws_ec2_match.start():]
        )
        if ssh_config_bytes_mod == ssh_config_bytes:
            sys.stderr.write(f"debug: skip updating {ssh_config_path} since update is not necessary\n")
            return
        ssh_config_backup_dir = ssh_config_path + "_backup"
        os.makedirs(ssh_config_backup_dir, exist_ok=True)
        mtime = os.path.getmtime(ssh_config_path)
        dt_mtime_str = datetime.datetime.fromtimestamp(mtime).strftime("%Y%m%d_%H%M%S")
        ssh_config_path_backup = os.path.join(ssh_config_backup_dir, dt_mtime_str)
        # os.replace also succeeds on Windows when a same-named backup exists.
        os.replace(ssh_config_path, ssh_config_path_backup)
        with open(ssh_config_path, "wb") as fp:
            fp.write(ssh_config_bytes_mod)
        # The new file is created with default permissions; restore the
        # original mode bits (e.g. 0600) from the file that was moved away.
        os.chmod(ssh_config_path, os.stat(ssh_config_path_backup).st_mode & 0o7777)

    @staticmethod
    def _remove_old_backups(ssh_config_path, keep=3, protected_paths=()):
        """Delete all but the *keep* newest (by mtime) SSH config backups.

        Candidates are the files in ``<ssh_config_path>_backup/`` plus regular
        files whose name starts with ``config`` in the directory of
        *ssh_config_path* (presumably backups from an older naming scheme).
        *ssh_config_path* itself and every path in *protected_paths* are never
        deleted.

        Args:
            ssh_config_path: the live SSH config file.
            keep: number of newest backups to retain.
            protected_paths: additional live files that must not be deleted.
        """
        candidates = []
        ssh_config_backup_dir = ssh_config_path + "_backup"
        if os.path.exists(ssh_config_backup_dir):
            candidates += [
                path
                for path in (os.path.join(ssh_config_backup_dir, fn) for fn in os.listdir(ssh_config_backup_dir))
                if os.path.isfile(path)
            ]
        else:
            sys.stderr.write(f"debug: {ssh_config_backup_dir} does not exist. \n")
        ssh_dirpath = os.path.dirname(ssh_config_path)
        if os.path.exists(ssh_dirpath):
            candidates += [
                path
                for path in (os.path.join(ssh_dirpath, fn) for fn in os.listdir(ssh_dirpath) if fn.startswith("config"))
                if os.path.isfile(path)
            ]
        # The "config*" scan also matches the live config file (and any sibling
        # configured file such as "config_work"); never treat those as backups.
        protected = {os.path.realpath(path) for path in (ssh_config_path, *protected_paths)}
        candidates = [path for path in candidates if os.path.realpath(path) not in protected]
        candidates.sort(key=os.path.getmtime)
        for backup_path in candidates[:max(len(candidates) - keep, 0)]:
            os.remove(backup_path)

    @staticmethod
    def _get_owner_id(describe_instances: dict):
        """Return the unique ``OwnerId`` of all reservations.

        Raises:
            ValueError: the reservations do not share exactly one owner.
        """
        owner_id_set = EC2Aux._get_owner_id_set(describe_instances)
        if len(owner_id_set) != 1:
            raise ValueError(f"expected exactly one owner id, got {owner_id_set}")
        return owner_id_set.pop()

    def update_ssh_config(self):
        """Regenerate the EC2 section of every configured SSH config file.

        One entry is written per non-terminated instance with a ``Name`` tag:
        ``Host <Name>``, ``User <User tag>`` when that tag exists, and
        ``HostName <public IP>`` (``0.0.0.0`` when none is assigned). Old
        backups are then pruned and EC2 hosts are removed from ``known_hosts``.
        """
        self._setup_ec2_client()
        sys.stderr.write("info: acquire ec2 info\n")
        describe_instances = self._get_describe_instances()
        ssh_config_ec2_list = []
        for instance_info in self._get_instance_info_list(describe_instances):
            tag_dict = self._get_tag_dict(instance_info)
            instance_name = tag_dict.get("Name")
            if instance_name is None:
                sys.stderr.write("warn: Name is not set. skipping...\n")
                continue
            sys.stderr.write(f"debug: {instance_name=}\n")
            if self._is_terminated(instance_info):
                # Would only add a duplicate Host entry pointing at 0.0.0.0.
                sys.stderr.write(f"debug: {instance_name} is terminated. skipping...\n")
                continue
            user_name = tag_dict.get("User")
            if user_name is None:
                sys.stderr.write("warn: User is not set.\n")
            else:
                sys.stderr.write(f"debug: {user_name=}\n")
            instance_id = instance_info["InstanceId"]
            sys.stderr.write(f"debug: {instance_id=}\n")
            instance_address = instance_info.get("PublicIpAddress")
            if instance_address is None:
                # Keep a placeholder entry so the Host alias still exists.
                instance_address = "0.0.0.0"
                sys.stderr.write("warn: address is not assigned.\n")
            else:
                sys.stderr.write(f"debug: {instance_address=}\n")
            ssh_config_ec2_i = "Host " + instance_name + "\n"
            if user_name is not None:
                ssh_config_ec2_i += "User " + user_name + "\n"
            ssh_config_ec2_i += "HostName " + instance_address + "\n"
            ssh_config_ec2_list.append(ssh_config_ec2_i)
        ssh_config_ec2 = "\n\n".join(ssh_config_ec2_list).encode()
        sys.stderr.write("info: update ssh config\n")
        for ssh_config_path in self.ssh_config_path_list:
            sys.stderr.write(f"debug: {ssh_config_path=}\n")
            self._update_ssh_config_one(ssh_config_path, ssh_config_ec2)
        sys.stderr.write("info: remove old backups\n")
        for ssh_config_path in self.ssh_config_path_list:
            sys.stderr.write(f"debug: {ssh_config_path=}\n")
            self._remove_old_backups(ssh_config_path, protected_paths=self.ssh_config_path_list)
        self.remove_ec2_hosts_from_known_hosts()

    def start_or_stop_instance(self, instance_name=None, start_or_stop=True, wait_complete=False):
        """Start (``start_or_stop=True``) or stop the instance tagged ``Name=instance_name``.

        Args:
            instance_name: value of the instance's ``Name`` tag.
            start_or_stop: True to start the instance, False to stop it.
            wait_complete: afterwards, block until the serial console shows a
                login prompt (see :meth:`_wait_boot_complete`).

        Raises:
            AttributeError: no non-terminated instance has that name.
            Exception: several instances share the name, or the API response
                does not report the expected state (the raw response is attached).
        """
        self._setup_ec2_client()
        sys.stderr.write("info: acquire ec2 info\n")
        describe_instances = self._get_describe_instances()
        instance_id = self._get_instance_info(describe_instances, instance_name)["InstanceId"]
        sys.stderr.write(f"info: {'start' if start_or_stop else 'stop'} instance\n")
        if start_or_stop:
            awscli_obj = self.ec2.start_instances(InstanceIds=[instance_id])
            state_changes = awscli_obj['StartingInstances']
            expected_codes = (0, 16)  # pending, running
        else:
            awscli_obj = self.ec2.stop_instances(InstanceIds=[instance_id])
            state_changes = awscli_obj['StoppingInstances']
            expected_codes = (64, 80)  # stopping, stopped
        # Explicit checks instead of assert so they also run under "python -O".
        if len(state_changes) != 1 or state_changes[0]["CurrentState"]["Code"] not in expected_codes:
            raise Exception(awscli_obj)
        if wait_complete:
            self._wait_boot_complete(instance_id)

    @staticmethod
    async def _serial_ping_stdin(serial_stdin: asyncio.StreamWriter, stop_ping_event: asyncio.Event):
        """Write a newline to *serial_stdin* about once per second until *stop_ping_event* is set."""
        stop_waiter = asyncio.ensure_future(stop_ping_event.wait())
        try:
            while True:
                sys.stderr.write("info: ping\n")
                serial_stdin.write(b"\n")
                await serial_stdin.drain()
                done, _ = await asyncio.wait((stop_waiter,), timeout=1)
                if stop_waiter in done:
                    break
        finally:
            # No-op when already finished; avoids a dangling waiter on errors.
            stop_waiter.cancel()

    @staticmethod
    async def _serial_read_stdout(serial_stdout: asyncio.StreamReader):
        """Consume console output until a line matches ``ip-A-B-C-D login:<whitespace>``.

        Raises:
            Exception: the stream ended before the prompt appeared.
        """
        pat = re.compile(r"^ip-\d+-\d+-\d+-\d+ login:\s+$")
        while True:
            line: bytes = await serial_stdout.readline()
            if line == b'':
                sys.stderr.write("debug: reading ssh stdout finished unexpectedly\n")
                raise Exception("serial console output ended before the login prompt")
            text = line.rstrip(b"\r\n").decode(errors="ignore")
            if not text:
                continue
            sys.stderr.write(f"debug: ssh: {text=}\n")
            if pat.match(text):
                break

    def _setup_ec2_ic(self):
        """Create the boto3 EC2 Instance Connect client and store it as ``self.ec2icc``."""
        self.ec2icc = boto3.client('ec2-instance-connect')

    async def _wait_boot_complete_async(self, instance_id):
        """Wait until the instance's serial console shows a login prompt.

        Pushes ``~/.ssh/id_rsa.pub`` via ``send_serial_console_ssh_public_key``
        (retrying every second), then connects with ``ssh`` to the serial
        console endpoint of the client's region and pings it with newlines
        until the login prompt is read.
        """
        self._setup_ec2_ic()
        ssh_pubkey_path = os.path.expanduser("~/.ssh/id_rsa.pub")
        with open(ssh_pubkey_path, "rt") as f:
            ssh_pubkey = f.read()
        e_prev = None
        while True:
            try:
                response = self.ec2icc.send_serial_console_ssh_public_key(
                    InstanceId=instance_id,
                    SerialPort=0,
                    SSHPublicKey=ssh_pubkey,
                )
            except Exception as e:
                # Presumably fails until the instance is far enough into boot;
                # report each exception type only once while retrying.
                if not isinstance(e, type(e_prev)):
                    sys.stderr.write(f"warn: exception occured. retrying... {e=}\n")
                e_prev = e
                await asyncio.sleep(1)
                continue
            if response.get("Success"):
                break
            sys.stderr.write(f"info: {response=}\n")
            await asyncio.sleep(1)
        region = self.ec2icc.meta.region_name or self._DEFAULT_SERIAL_CONSOLE_REGION
        ssh_proc = await asyncio.create_subprocess_exec(
            "ssh", "-tt", "-l", f"{instance_id}.port0", f"serial-console.ec2-instance-connect.{region}.aws",
            stdout=asyncio.subprocess.PIPE,
            stdin=asyncio.subprocess.PIPE,
        )
        stop_ping_event = asyncio.Event()
        ping_stdin_task = asyncio.create_task(EC2Aux._serial_ping_stdin(ssh_proc.stdin, stop_ping_event))
        try:
            await EC2Aux._serial_read_stdout(ssh_proc.stdout)
            stop_ping_event.set()
            await ping_stdin_task
            # ssh escape sequence "~." closes the connection.
            ssh_proc.stdin.write(b"~.")
        finally:
            # Also reached when the prompt never appears: don't leave the ping
            # task or the ssh process behind.
            ping_stdin_task.cancel()
            if ssh_proc.returncode is None:
                try:
                    ssh_proc.kill()
                except ProcessLookupError:
                    pass  # exited between the check and the kill
            ret_ssh = await ssh_proc.wait()
            sys.stderr.write(f"debug: {ret_ssh=}\n")

    def _wait_boot_complete(self, instance_id):
        """Synchronous wrapper around :meth:`_wait_boot_complete_async`."""
        return asyncio.run(self._wait_boot_complete_async(instance_id))

    @staticmethod
    def _get_ec2_hostname(known_hosts_line):
        """Return the EC2 public DNS name a known_hosts line refers to, else None.

        Only the first whitespace-separated field is examined. A literal IP
        address is reverse-resolved with ``socket.gethostbyaddr``; addresses
        without a reverse DNS name are treated as non-EC2.
        """
        fields = known_hosts_line.split()
        if not fields:
            return None
        address = fields[0]
        try:
            ipaddress.ip_address(address)
        except ValueError:
            fqdn_address = address
        else:
            try:
                fqdn_address = socket.gethostbyaddr(address)[0]
            except OSError:
                # socket.herror / socket.gaierror are OSError subclasses.
                return None
        if fqdn_address.startswith("ec2-") and fqdn_address.endswith(EC2Aux._EC2_HOSTNAME_SUFFIXES):
            return fqdn_address
        return None

    def remove_ec2_hosts_from_known_hosts(self):
        """Drop EC2 host entries from the ``known_hosts`` next to each SSH config.

        The previous content is kept as ``known_hosts.old`` (overwritten each
        time). Missing ``known_hosts`` files are skipped.
        """
        for ssh_config_path in self.ssh_config_path_list:
            ssh_config_dirpath = os.path.dirname(ssh_config_path)
            known_hosts_path = os.path.join(ssh_config_dirpath, "known_hosts")
            sys.stderr.write(f"info: remove ec2 hosts from {known_hosts_path}\n")
            if not os.path.isfile(known_hosts_path):
                sys.stderr.write(f"debug: {known_hosts_path} does not exist. skipping...\n")
                continue
            known_hosts_mod_lines = []
            has_change = False
            with open(known_hosts_path, "rt") as fp:
                for known_hosts_line in fp:
                    fqdn_address = self._get_ec2_hostname(known_hosts_line)
                    if fqdn_address is not None:
                        sys.stderr.write(f"info: {fqdn_address} is ec2 address. skipping...\n")
                        has_change = True
                        continue
                    known_hosts_mod_lines.append(known_hosts_line)
            if not has_change:
                sys.stderr.write("info: no ec2 address is found.\n")
                # continue, not return: the remaining config paths still need processing.
                continue
            known_hosts_backup_path = os.path.join(ssh_config_dirpath, "known_hosts.old")
            os.replace(known_hosts_path, known_hosts_backup_path)
            with open(known_hosts_path, "wt") as fp:
                fp.writelines(known_hosts_mod_lines)
            os.chmod(known_hosts_path, os.stat(known_hosts_backup_path).st_mode & 0o7777)
def main():
    """Parse the command line and dispatch to the matching EC2Aux action.

    Sub-commands: ``start-instance NAME [--wait-complete]``,
    ``stop-instance NAME`` and ``update-ssh-config``. ``-s/--print-completion``
    is provided by shtab.

    Returns:
        The action's return value (passed to ``sys.exit`` by the caller), or 1
        after printing help to stderr when no sub-command was given.
    """
    parser = argparse.ArgumentParser(add_help=True)
    shtab.add_argument_to(parser, ["-s", "--print-completion"])
    subparsers = parser.add_subparsers(title="command", dest="command")
    parser_start_instance = subparsers.add_parser("start-instance", help="start instance")
    parser_start_instance.add_argument("instance_name")
    parser_start_instance.add_argument("--wait-complete", dest="wait_complete", action="store_true", help="wait boot completion")
    parser_stop_instance = subparsers.add_parser("stop-instance", help="stop instance")
    parser_stop_instance.add_argument("instance_name")
    subparsers.add_parser("update-ssh-config", help="update ssh config")
    args = parser.parse_args()
    if args.command == "start-instance":
        return EC2Aux().start_or_stop_instance(args.instance_name, True, wait_complete=args.wait_complete)
    if args.command == "stop-instance":
        return EC2Aux().start_or_stop_instance(args.instance_name, False)
    if args.command == "update-ssh-config":
        return EC2Aux().update_ssh_config()
    # No sub-command: previously exited silently with status 0.
    parser.print_help(sys.stderr)
    return 1
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment