Skip to content

Instantly share code, notes, and snippets.

@mvanderlee
Last active May 7, 2020 18:06
Show Gist options
  • Save mvanderlee/37534a3c2dfe63f892cb43891302f317 to your computer and use it in GitHub Desktop.
Save mvanderlee/37534a3c2dfe63f892cb43891302f317 to your computer and use it in GitHub Desktop.
AWS CLI utilities
'''
Stores an MFA session token in the AWS CLI credentials file (~/.aws/credentials).
Recommended usage:
* Set the 'MFA_ARN' environment variable (or put it in a .env file)
* Run `python aws_mfa.py`. The script will prompt you for the MFA code.
'''
import boto3
import click
import coloredlogs
import configparser
import logging
import os
from environs import Env
# Console logging: print bare messages, with INFO in green and DEBUG in blue
# layered over coloredlogs' default level styles.
coloredlogs.install(
    level=logging.INFO,
    fmt='%(message)s',
    level_styles={
        **coloredlogs.DEFAULT_LEVEL_STYLES,
        'info': {'color': 'green'},
        'debug': {'color': 'blue'},
    },
)
logger = logging.getLogger(__name__)

# Load environment variables (e.g. MFA_ARN) from a .env file.
Env().read_env(verbose=True)
def read_ini_file_to_dict(file_path):
    '''
    Read an ini file and return its content as a dict of dicts.

    Returns ``{section_name: {option: value}}``. A missing or unopenable file
    yields an empty dict, because ``ConfigParser.read`` skips files it cannot open.

    Interpolation is disabled so values are returned verbatim; with the default
    ``BasicInterpolation`` any value containing ``%`` would raise on access.
    Options from a ``[DEFAULT]`` section are merged into every section and
    ``DEFAULT`` itself is not returned (standard ``ConfigParser`` behaviour).
    '''
    parser = configparser.ConfigParser(interpolation=None)
    parser.read(file_path)
    return {section: dict(parser.items(section)) for section in parser.sections()}
def write_dict_to_ini_file(dict_, file_path):
    '''
    Write a dict of dicts to an ini file, replacing any existing content.

    :param dict_: ``{section_name: {option: value}}``; values are converted with ``str``.
    :param file_path: destination path; missing parent directories are created.

    Interpolation is disabled so values containing ``%`` are written verbatim
    instead of being rejected. A newly created file is readable by its owner
    only, since it is used for credentials; an existing file keeps its mode.
    '''
    parser = configparser.ConfigParser(interpolation=None)
    parser.read_dict(dict_)

    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        # e.g. ~/.aws does not exist yet on a fresh machine.
        os.makedirs(parent_dir, exist_ok=True)

    def _private_opener(path, flags):
        # The mode argument only applies when the file is created.
        return os.open(path, flags, 0o600)

    with open(file_path, 'w', opener=_private_opener) as fout:
        parser.write(fout)
@click.command()
@click.option('-a', '--mfa-arn', envvar='MFA_ARN', prompt='MFA arn', help='The identification number of the MFA device that is associated with the IAM user. i.e.: "arn:aws:iam::123456789012:mfa/tony.stark". You can find this on the IAM page.')
@click.option('-c', '--code', envvar='CODE', prompt='Code', help='The code generated by your MFA device.')
@click.option('-sp', '--source-profile', default='mfa-source', envvar='SOURCE_PROFILE', prompt='Source Profile', help='What AWS profile to get the session token with.')
@click.option('-tp', '--target-profile', default='default', envvar='TARGET_PROFILE', prompt='Target Profile', help='What AWS profile to store the credentials under.')
def cli(
    mfa_arn,
    code,
    source_profile,
    target_profile,
    **kwargs
):
    '''
    Get an MFA session token using the source profile and store it under the
    target profile in ~/.aws/credentials.
    '''
    if source_profile == target_profile:
        # Writing the temporary credentials over the source profile would destroy
        # the long-lived keys, and STS does not accept GetSessionToken calls made
        # with temporary credentials, so every later run would fail.
        raise click.BadParameter(
            'must differ from the source profile, otherwise its long-lived keys are overwritten.',
            param_hint="'--target-profile'",
        )

    # Get MFA creds
    session = boto3.session.Session(profile_name=source_profile)
    sts = session.client('sts')
    session_token = sts.get_session_token(
        SerialNumber=mfa_arn,
        TokenCode=code
    )
    session_creds = session_token['Credentials']

    # Get creds from user's home
    creds_path = os.path.join(os.path.expanduser('~'), '.aws', 'credentials')
    aws_creds = read_ini_file_to_dict(creds_path)

    # Replace the whole target section with the temporary credentials
    aws_creds.update({
        target_profile: {
            'aws_access_key_id': session_creds['AccessKeyId'],
            'aws_secret_access_key': session_creds['SecretAccessKey'],
            'aws_session_token': session_creds['SessionToken']
        }
    })
    write_dict_to_ini_file(aws_creds, creds_path)
    logger.info(f'Updated {creds_path}, use profile {target_profile} for your AWS requests.')
'''
Asks user which Cluster and EC2 instance they want to connect to,
then prints the ssh options required to connect to the selected instance.
Create this alias on your unix system.
alias emr_ssh='ssh `emr ssh`'
Then just run: emr_ssh
'''
import boto3
import click
import click_spinner
import coloredlogs
import logging
import os
import questionary
from botocore.exceptions import ClientError
from environs import Env
# Module-level logger; the cli() group callback replaces it with the logger
# returned by init_logging() before any subcommand runs.
logger = None
Env().read_env() # Load .env file
def init_logging(long_log=False, debug=False, verbose=False, quiet=False):
    '''
    Configure console logging for the CLI and return this module's logger.

    :param long_log: keep coloredlogs' default (long) format instead of message-only output.
    :param debug: enable DEBUG output for every logger (root level).
    :param verbose: enable DEBUG output for this module's logger only; every other
        registered logger is capped at INFO. Ignored when ``debug`` is set.
    :param quiet: drop all records up to and including ERROR; the other flags are ignored.
    :returns: ``logging.getLogger(__name__)``.
    '''
    if quiet:
        # Process-wide: only CRITICAL records still get through.
        logging.disable(logging.ERROR)
        coloredlogs.install(level=logging.CRITICAL, fmt='%(message)s')
        return logging.getLogger(__name__)

    coloredlogs_config = {
        'level': logging.INFO,
        'level_styles': dict(coloredlogs.DEFAULT_LEVEL_STYLES, **{
            'info': {'color': 'green'},
            'debug': {'color': 'blue'},
        }),
    }
    if not long_log:
        coloredlogs_config['fmt'] = '%(message)s'
    coloredlogs.install(**coloredlogs_config)

    # Keep botocore's credential-lookup chatter out of the output.
    logging.getLogger('botocore.credentials').setLevel(logging.WARNING)

    logger = logging.getLogger(__name__)
    if debug:
        coloredlogs.set_level(logging.DEBUG)
        logger.setLevel(logging.DEBUG)
        logger.debug('Debug enabled')
    elif verbose:
        coloredlogs.set_level(logging.DEBUG)
        logger.setLevel(logging.DEBUG)
        logger.debug('Verbose enabled')
        # Only our own logger should be verbose: cap every other registered logger
        # at INFO, without making already-quieter ones (botocore.credentials) louder.
        for name, sub_logger in list(logger.manager.loggerDict.items()):
            if name != __name__ and not isinstance(sub_logger, logging.PlaceHolder):
                sub_logger.setLevel(max(logging.INFO, sub_logger.level))
    return logger
def get_emr_clusters():
    '''
    Return ``{cluster_name: cluster_id}`` for every RUNNING or WAITING EMR cluster.

    ListClusters is paginated, so all pages are read. If several clusters share
    a name, the one listed last wins.
    '''
    emr = boto3.client('emr')
    paginator = emr.get_paginator('list_clusters')
    clusters = {}
    for page in paginator.paginate(ClusterStates=['RUNNING', 'WAITING']):
        for cluster in page['Clusters']:
            clusters[cluster['Name']] = cluster['Id']
    return clusters
def get_emr_instances(cluster_id):
    '''
    Return the private IPs of a cluster's RUNNING instances, grouped by instance group name.

    The result is ``{group_name: [private_ip, ...]}``. Every instance group of the
    cluster gets a key, even when it currently has no running instances. Both
    listing calls are paginated, so all pages are read.
    '''
    emr = boto3.client('emr')

    group_names = {}  # instance group id -> instance group name
    for page in emr.get_paginator('list_instance_groups').paginate(ClusterId=cluster_id):
        for group in page['InstanceGroups']:
            group_names[group['Id']] = group['Name']

    grouped_instances = {name: [] for name in group_names.values()}
    instance_pages = emr.get_paginator('list_instances').paginate(
        ClusterId=cluster_id, InstanceStates=['RUNNING']
    )
    for page in instance_pages:
        for instance in page['Instances']:
            group_name = group_names[instance['InstanceGroupId']]
            grouped_instances[group_name].append(instance['PrivateIpAddress'])
    return grouped_instances
def get_instance_key_name(cluster_id):
    '''
    Return the EC2 key pair name the cluster's instances were launched with,
    or None when the cluster has no key pair (the field is then absent).
    '''
    emr = boto3.client('emr')
    cluster = emr.describe_cluster(ClusterId=cluster_id)
    return cluster['Cluster']['Ec2InstanceAttributes'].get('Ec2KeyName')
def try_to_find_ssh_key_file(key_name, search_dir='~/.ssh/'):
    '''
    Search ``search_dir`` recursively for the private key file of ``key_name``.

    A file matches when its name equals ``key_name`` or its name without the
    extension does (e.g. ``mykey.pem``). Public keys (``*.pub``) are skipped
    because ``ssh -i`` needs the private key. Directories and files are visited
    in sorted order so the result is deterministic.

    :returns: the path of the first match, or None if there is none
        (always None for an empty/None ``key_name``).
    '''
    if not key_name:
        return None
    for dirpath, dirnames, filenames in os.walk(os.path.expanduser(search_dir)):
        dirnames.sort()  # in-place so os.walk descends in sorted order
        for filename in sorted(filenames):
            stem, ext = os.path.splitext(filename)
            if ext == '.pub':
                continue
            if filename == key_name or stem == key_name:
                return os.path.join(dirpath, filename)
    return None
def get_ssh_options():
    '''
    Interactively pick an EMR cluster, instance group and instance.

    Returns the ssh arguments for the chosen instance as one string:
    ``-i <key file> hadoop@<ip>`` when the cluster's key pair file is found
    locally, otherwise ``hadoop@<ip>`` (after the user agrees to continue).

    Exits the process (status 1) when there is no cluster or running instance
    to choose from, or when the user declines to continue without a key file.
    '''
    with click_spinner.spinner():
        clusters = get_emr_clusters()
    if not clusters:
        # SystemExit with a message prints it to stderr and exits with status 1.
        raise SystemExit('No RUNNING or WAITING EMR clusters found.')
    cluster_name = questionary.select(
        'Which cluster do you want to connect to?',
        choices=sorted(clusters)
    ).unsafe_ask()
    cluster_id = clusters[cluster_name]

    with click_spinner.spinner():
        grouped_instances = get_emr_instances(cluster_id)
    # Groups without running instances cannot be connected to; offering them
    # would make instances[0] below raise IndexError.
    grouped_instances = {name: ips for name, ips in grouped_instances.items() if ips}
    if not grouped_instances:
        raise SystemExit(f'Cluster {cluster_name} has no running instances.')
    instance_group = questionary.select(
        'Which instance group do you want to connect to?',
        choices=sorted(grouped_instances)
    ).unsafe_ask()
    instances = grouped_instances[instance_group]

    instance_ip = instances[0]
    if len(instances) > 1:
        instance_ip = questionary.select(
            'Which instance do you want to connect to?',
            choices=sorted(instances)
        ).unsafe_ask()

    with click_spinner.spinner():
        key_name = get_instance_key_name(cluster_id)
        key_file = try_to_find_ssh_key_file(key_name)
    if key_file is None:
        should_continue = questionary.confirm(f'Could not find the ssh key {key_name}, would you like to continue?').unsafe_ask()
        if not should_continue:
            raise SystemExit(1)

    if key_file:
        return f'-i {key_file} hadoop@{instance_ip}'
    return f'hadoop@{instance_ip}'
@click.group()
@click.option('-ll/', '--long-log/--no-long-log', default=False, help="Use coloredlogs' full log format instead of message-only output")
@click.option('--debug/--no-debug', default=False, help='Enable root debug logging')
@click.option('--verbose/--no-verbose', default=False, help='Enable debug logging')
@click.option('--quiet/--no-quiet', default=True, help='Disable logging (default). Ignored when --debug or --verbose is given')
def cli(long_log, debug, verbose, quiet):
    '''
    EMR helper commands.
    '''
    global logger
    # quiet is on by default; an explicit --debug/--verbose must not be
    # silently swallowed by it (init_logging ignores both flags when quiet).
    logger = init_logging(long_log, debug, verbose, quiet and not (debug or verbose))
@cli.command('ssh')
def ssh(**kwargs):
    '''
    Asks user which Cluster and EC2 instance they want to connect to,
    then prints the ssh options required to connect to the selected instance.
    Create this alias on your unix system.
    alias emr_ssh='ssh `emr ssh`'
    '''
    try:
        ssh_options = get_ssh_options()
    except ClientError as e:
        error_code = e.response.get('Error', {}).get('Code')
        if error_code not in ('ExpiredToken', 'ExpiredTokenException'):
            # Any other AWS error is unexpected; surface it instead of exiting
            # silently with status 0 and no output.
            raise
        logger.critical('Your AWS Token has expired. Please update and try again.')
        raise SystemExit(1)
    print(ssh_options)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment