Last active
February 12, 2024 17:12
-
-
Save alukach/664ef511e04b668dc6e3c9d73983786c to your computer and use it in GitHub Desktop.
AWS Helper Scripts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This script provides a CLI to select an AWS ECS Service and multiple RDS Instances | |
and makes the required Security Group edits to allow the ECS Service to make network | |
connections to the RDS Instances | |
""" | |
from typing import List, Dict | |
import boto3 | |
from botocore.exceptions import ClientError | |
import inquirer | |
def list_ecs_clusters() -> List[str]:
    """Return the ARNs of every ECS cluster in the current account/region.

    ``ListClusters`` is a paginated API: a single call returns only one page
    and silently drops the rest. The paginator walks every page.

    :return: A list of cluster ARNs (possibly empty).
    """
    ecs = boto3.client("ecs")
    paginator = ecs.get_paginator("list_clusters")
    return [
        cluster_arn
        for page in paginator.paginate()
        for cluster_arn in page["clusterArns"]
    ]
def list_ecs_services(cluster) -> List[str]:
    """Return the ARNs of every service in the given ECS cluster.

    ``ListServices`` returns at most 10 services per call unless
    ``maxResults`` is set. Without pagination, services in larger clusters
    would never be offered in the prompt.

    :param cluster: Name or ARN of the ECS cluster to inspect.
    :return: A list of service ARNs (possibly empty).
    """
    ecs = boto3.client("ecs")
    paginator = ecs.get_paginator("list_services")
    return [
        service_arn
        for page in paginator.paginate(cluster=cluster)
        for service_arn in page["serviceArns"]
    ]
def list_rds_instances() -> List[Dict]:
    """Return the description of every RDS DB instance in the account/region.

    Each item is a ``DBInstances`` entry from ``DescribeDBInstances``.
    Callers read ``DBInstanceIdentifier`` from it. The API is paginated, so
    every page is fetched instead of only the first.

    :return: A list of DB instance description dicts (possibly empty).
    """
    rds = boto3.client("rds")
    paginator = rds.get_paginator("describe_db_instances")
    return [
        instance
        for page in paginator.paginate()
        for instance in page["DBInstances"]
    ]
def get_security_group_from_ecs_service(cluster, service_arn) -> str:
    """Return the first security group attached to an ECS service's tasks.

    The awsvpc network configuration is read from the PRIMARY deployment,
    i.e. the most recent one. During a rollout an older ACTIVE deployment can
    also be listed, so "first in the list" is not reliable. If no PRIMARY
    deployment carries a network configuration, the service-level
    ``networkConfiguration`` is used instead.

    :param cluster: Name or ARN of the cluster that owns the service.
    :param service_arn: ARN (or name) of the ECS service.
    :return: The ID of the first security group in the awsvpc configuration.
    :raises ValueError: If the service is not found or has no awsvpc
        security groups.
    """
    ecs = boto3.client("ecs")
    details = ecs.describe_services(cluster=cluster, services=[service_arn])
    if not details["services"]:
        raise ValueError(
            f"ECS service {service_arn!r} not found in cluster {cluster!r}: "
            f"{details.get('failures')}"
        )
    service = details["services"][0]

    primary = next(
        (
            deployment
            for deployment in service.get("deployments", [])
            if deployment.get("status") == "PRIMARY"
        ),
        None,
    )
    network_config = (primary or {}).get("networkConfiguration") or service.get(
        "networkConfiguration", {}
    )
    groups = network_config.get("awsvpcConfiguration", {}).get("securityGroups", [])
    if not groups:
        raise ValueError(
            f"ECS service {service_arn!r} has no awsvpc security groups"
        )
    return groups[0]
def get_security_group_ids_for_rds_instance(instance_identifier) -> List[str]:
    """
    Returns the VPC security group IDs attached to an RDS instance.

    AWS/botocore failures (unknown identifier, missing credentials or region,
    throttling, ...) are printed and turned into an empty list, so a caller
    can skip the instance. Other exceptions (programming errors) are no
    longer swallowed.

    :param instance_identifier: The identifier of the RDS instance
    :return: A list of security group IDs associated with the RDS instance,
        or an empty list if none could be determined.
    """
    from botocore.exceptions import BotoCoreError, ClientError

    try:
        rds = boto3.client("rds")
        response = rds.describe_db_instances(DBInstanceIdentifier=instance_identifier)
    except (BotoCoreError, ClientError) as e:
        print(
            f"Error fetching security group IDs for RDS instance '{instance_identifier}': {e}"
        )
        return []

    db_instances = response["DBInstances"]
    if not db_instances:
        return []
    # The instance may have no VPC security groups at all; treat a missing or
    # empty list as "none" rather than raising KeyError.
    return [
        sg["VpcSecurityGroupId"]
        for sg in db_instances[0].get("VpcSecurityGroups", [])
    ]
def modify_security_group_rules(
    security_group_id,
    protocol,
    from_port,
    to_port,
    source_security_group_id: str,
    description: str,
    dry_run: bool,
) -> None:
    """Allow ingress into one security group from another security group.

    Adds one ingress rule to ``security_group_id``. The rule permits
    ``protocol`` traffic on ports ``from_port``..``to_port`` from members of
    ``source_security_group_id``. Errors are printed rather than raised, so a
    caller looping over several groups continues with the next one.

    :param security_group_id: ID of the security group to modify.
    :param protocol: IP protocol passed to EC2 as ``IpProtocol``, e.g. ``"tcp"``.
    :param from_port: First port of the allowed range.
    :param to_port: Last port of the allowed range.
    :param source_security_group_id: ID of the group whose members may connect.
    :param description: Description stored on the rule.
    :param dry_run: When true, only report what would change. No AWS client
        is created and no API call is made.
    """
    if dry_run:
        print(
            f"Dry run: Would update Security Group {security_group_id!r} "
            f"to allow connections from {source_security_group_id!r}"
        )
        return

    # Created only after the dry-run check, so a dry run needs no AWS
    # configuration (creating a client can fail when no region is set).
    ec2 = boto3.client("ec2")
    try:
        ec2.authorize_security_group_ingress(
            GroupId=security_group_id,
            IpPermissions=[
                {
                    "IpProtocol": protocol,
                    "FromPort": from_port,
                    "ToPort": to_port,
                    "UserIdGroupPairs": [
                        {
                            "GroupId": source_security_group_id,
                            "Description": description,
                        }
                    ],
                }
            ],
        )
    except ClientError as e:
        # EC2 rejects an identical existing rule with InvalidPermission.Duplicate.
        # Re-running the script is expected, so report it as already done.
        if e.response.get("Error", {}).get("Code") == "InvalidPermission.Duplicate":
            print(
                f"Security Group {security_group_id} already allows connections "
                f"from {source_security_group_id}."
            )
        else:
            print(f"Error updating Security Group: {e}")
        return
    print(f"Security Group {security_group_id} updated successfully.")
if __name__ == "__main__":
    dry_run = inquirer.confirm(
        message="Dry run (no changes will be made)?",
        default=True,
    )
    ecs_clusters = list_ecs_clusters()
    if not ecs_clusters:
        raise SystemExit("No ECS clusters found.")
    rds_instances = list_rds_instances()

    # Port each instance listens on, taken from its endpoint. 5432 is only a
    # fallback (e.g. an instance still being created has no Endpoint yet).
    rds_ports = {
        instance["DBInstanceIdentifier"]: instance.get("Endpoint", {}).get("Port", 5432)
        for instance in rds_instances
    }

    # Select ECS Cluster
    resource_questions = [
        inquirer.List(
            "cluster",
            message="Select ECS Cluster that contains the service that requires databases access",
            choices=ecs_clusters,
            default=lambda answers: next(
                (cluster for cluster in ecs_clusters if "grafana" in cluster), None
            ),
        ),
        inquirer.List(
            "service",
            message="Select ECS Service in {cluster} that requires database access",
            choices=lambda answers: list_ecs_services(answers["cluster"]),
        ),
        inquirer.Checkbox(
            "rds_instances",
            message="Select RDS Instances that will be accessed via {cluster}/{service}",
            choices=[instance["DBInstanceIdentifier"] for instance in rds_instances],
        ),
        inquirer.Text(
            "description",
            message="Provide a description for the connection",
            default=lambda answers: f"Allow connections from ECS service: {answers['service'].split(':')[-1]}",
        ),
    ]
    resources = inquirer.prompt(resource_questions)
    if resources is None:
        # inquirer.prompt returns None when the user cancels (Ctrl-C).
        raise SystemExit("Aborted.")

    # Get the Security Group and port of each selected RDS Instance. Instances
    # without a security group are skipped instead of crashing on [0].
    # (group, port) pairs are de-duplicated because several instances can
    # share one group, and authorizing the same rule twice fails as a duplicate.
    rds_targets = []
    for rds_instance in resources["rds_instances"]:
        group_ids = get_security_group_ids_for_rds_instance(rds_instance)
        if not group_ids:
            print(f"Skipping RDS instance {rds_instance!r}: no security group found.")
            continue
        target = (group_ids[0], rds_ports.get(rds_instance, 5432))
        if target not in rds_targets:
            rds_targets.append(target)

    # Get Security Group of selected ECS Service
    ecs_security_group = get_security_group_from_ecs_service(
        resources["cluster"], resources["service"]
    )

    # Update RDS Instances' Security Groups to allow inbound connections from ECS Service
    for rds_sg_id, port in rds_targets:
        modify_security_group_rules(
            security_group_id=rds_sg_id,
            source_security_group_id=ecs_security_group,
            protocol="tcp",
            to_port=port,
            from_port=port,
            description=resources["description"],
            dry_run=dry_run,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This script allows you to select an AWS SecretsManager Secret and attempts to generate | |
a Postgresql connection string from that secret. | |
""" | |
import json | |
import sys | |
from blessed import Terminal | |
import boto3 | |
import inquirer, inquirer.render.console | |
class StdErrRenderer(inquirer.render.ConsoleRender):
    """``inquirer`` console renderer that draws prompts on stderr.

    Every method below writes to ``sys.stderr``: either directly, or through
    a ``blessed`` ``Terminal`` bound to stderr. Interactive UI therefore never
    reaches stdout, presumably so the script's result on stdout can be piped
    or captured.

    NOTE(review): these overrides rely on private ``ConsoleRender`` internals
    (``_event_loop``, ``_theme``, ``_position``, ``_force_initial_column``)
    and appear to mirror the upstream implementation. Re-check them when
    upgrading ``inquirer``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # All cursor movement and output below goes through this stderr-bound terminal.
        self.terminal = Terminal(stream=sys.stderr)

    def render(self, question, answers=None):
        # Render one question and return the answer produced by the event loop.
        question.answers = answers or {}
        # Ignored questions are not shown; they resolve to their default.
        if question.ignore:
            return question.default
        clazz = self.render_factory(question.kind)
        render = clazz(
            question,
            terminal=self.terminal,
            theme=self._theme,
            show_default=question.show_default,
        )
        self.clear_eos()
        try:
            return self._event_loop(render)
        finally:
            # Always finish on a fresh line, even if the event loop raises.
            print("", file=self.terminal.stream)

    def _relocate(self):
        # Move the cursor back up over the lines printed so far (counted in
        # self._position by print_str), then reset the counter.
        print(self._position * self.terminal.move_up, end="", file=self.terminal.stream)
        self._force_initial_column()
        self._position = 0

    def _go_to_end(self, render):
        # NOTE(review): `positions` is only used as a guard; the cursor is moved
        # down by self._position, not by `positions`. This looks like the
        # upstream inquirer code it was copied from. Confirm before "fixing".
        positions = len(list(render.get_options())) - self._position
        if positions > 0:
            print(
                self._position * self.terminal.move_down,
                end="",
                file=self.terminal.stream,
            )
        self._position = 0

    def print_str(self, base, lf=False, **kwargs):
        # Format `base` with the terminal bound to `t` (so `{t.<style>}`
        # placeholders resolve) and write it to stderr. Each line feed is
        # counted so _relocate knows how far to move back up.
        if lf:
            self._position += 1

        print(
            base.format(t=self.terminal, **kwargs),
            end="\n" if lf else "",
            flush=True,
            file=sys.stderr,
        )

    def clear_eos(self):
        # Clear from the cursor to the end of the screen on stderr.
        print(self.terminal.clear_eos(), end="", file=self.terminal.stream)
def fetch_aws_secrets():
    """Yield the names of Secrets Manager secrets that look database-related.

    Every page of ``list_secrets`` is walked lazily. A secret is kept when its
    lower-cased name contains "db", "rds" or "database".
    """
    keywords = ["db", "rds", "database"]
    pages = boto3.client("secretsmanager").get_paginator("list_secrets").paginate()
    for page in pages:
        for entry in page["SecretList"]:
            lowered = entry["Name"].lower()
            if any(keyword in lowered for keyword in keywords):
                yield entry["Name"]
def get_secret_value(secret_name):
    """Fetch a secret from AWS Secrets Manager and decode it as a JSON object.

    :param secret_name: Name or ARN of the secret.
    :return: The decoded key/value mapping. An empty dict is returned when the
        secret has no ``SecretString`` (binary secrets use ``SecretBinary``
        instead) or when its string is not a JSON object. ``main`` treats an
        empty result as "failed to retrieve secret details".
    """
    client = boto3.client("secretsmanager")
    response = client.get_secret_value(SecretId=secret_name)
    secret = response.get("SecretString")
    if not secret:
        return {}
    try:
        details = json.loads(secret)
    except json.JSONDecodeError:
        return {}
    # A JSON string/list/number cannot be formatted into a connection string.
    return details if isinstance(details, dict) else {}
def format_postgres_connection_string(secret_details):
    """Build a ``postgresql://`` connection URI from secret details.

    The user name, password and database name are percent-encoded.
    Generated passwords often contain characters such as ``@``, ``:``, ``/``,
    ``#`` or ``%``. Unencoded, these would be read as URI delimiters and
    corrupt the result.

    :param secret_details: Mapping with ``host``, ``username``, ``password``
        and ``port`` keys, plus an optional ``dbname``.
    :return: ``postgresql://user:password@host:port[/dbname]``.
    :raises KeyError: If a required key is missing.
    """
    from urllib.parse import quote

    host = secret_details["host"]
    user = quote(str(secret_details["username"]), safe="")
    password = quote(str(secret_details["password"]), safe="")
    port = secret_details["port"]
    dbname = secret_details.get("dbname")
    database = f"/{quote(str(dbname), safe='')}" if dbname else ""
    return f"postgresql://{user}:{password}@{host}:{port}{database}"
def main():
    """Interactively pick a secret and print a Postgres connection string.

    Prompts and status messages go to stderr. Only the connection string is
    written to stdout (without a trailing newline), so it can be captured
    with e.g. ``psql "$(python script.py)"``.
    """
    print("Fetching AWS Secrets...", file=sys.stderr)
    secrets = list(fetch_aws_secrets())
    if not secrets:
        # Avoid showing an empty selection list.
        print("No database-related secrets found.", file=sys.stderr)
        return

    questions = [
        inquirer.List(
            "secret",
            message="Select an AWS Secret containing Postgres connection info",
            choices=secrets,
        ),
    ]
    answers = inquirer.prompt(questions, render=StdErrRenderer())
    if answers is None:
        # inquirer.prompt returns None when the user cancels (Ctrl-C).
        return

    secret_details = get_secret_value(answers["secret"])
    if not secret_details:
        print("Failed to retrieve secret details.", file=sys.stderr)
        return

    connection_string = format_postgres_connection_string(secret_details)
    print("Your Postgres connection string is: ", file=sys.stderr, end="")
    print(connection_string, end="")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment