Skip to content

Instantly share code, notes, and snippets.

@andreif
Last active February 19, 2024 09:29
Show Gist options
  • Save andreif/2bb761595903ea6c68ce5c1e36a8fef3 to your computer and use it in GitHub Desktop.
Save andreif/2bb761595903ea6c68ce5c1e36a8fef3 to your computer and use it in GitHub Desktop.
AWS auth via SSO OIDC
import os
from time import sleep
import boto3
from utils import dump as _print, list_all
import cache
# Passed as `force=` to every `cache.get` call below; presumably True makes the
# local `cache` module recompute instead of reusing stored results — confirm.
FORCE = False
ROLE_NAME = os.environ['ROLE_NAME']
ACCOUNT_ID = os.environ['ACCOUNT_ID']
# IAM Identity Center (SSO) portal start URL, built from the OIDC_ID env var.
START_URL = 'https://%s.awsapps.com/start' % os.environ['OIDC_ID']
REGION = 'eu-north-1'

session = boto3.Session(region_name=REGION)
sso = session.client('sso')
sso_oidc = session.client('sso-oidc')

# Step 1: register this script as a public OIDC client.
# NOTE(review): `_print` dumps the whole response, including clientSecret.
r = cache.get('register_client', lambda: sso_oidc.register_client(
    clientName='test', clientType='public',
    scopes=['sso:account:access'],
), force=FORCE)
_print(r)
client_id = r['clientId']
client_secret = r['clientSecret']

# Step 2: start the device authorization flow for START_URL.
# NOTE(review): device codes expire; if `cache` keeps this response across
# runs, create_token may fail with an expired-token error — confirm expiry.
r = cache.get('start_device_authorization', lambda: sso_oidc.start_device_authorization(
    clientId=client_id,
    clientSecret=client_secret,
    startUrl=START_URL,
), force=FORCE)
_print(r)
device_code = r['deviceCode']
interval = r['interval']
# The user has to open this URL to approve the request. Print it directly so
# it is visible even when the external jq/yq formatter behind `_print` fails.
verification_url = r.get('verificationUriComplete') or r.get('verificationUri')
if verification_url:
    print('Approve this device at:', verification_url)
def create_token():
    """Poll SSO OIDC ``create_token`` until the device authorization is approved.

    Uses the module-level ``client_id``, ``client_secret`` and ``device_code``
    obtained during client registration / device authorization. While the
    request is still pending, one dot is printed per attempt and the call is
    retried every ``interval`` seconds. On ``SlowDownException`` the global
    ``interval`` grows by 5 seconds before waiting (RFC 8628 section 3.5).

    Returns:
        The raw ``create_token`` response dict (it contains ``accessToken``).

    Raises:
        SystemExit: with status 1 when the device code has expired or is
            rejected as an invalid grant.
    """
    global interval
    try:
        while True:
            try:
                return sso_oidc.create_token(
                    clientId=client_id,
                    clientSecret=client_secret,
                    grantType='urn:ietf:params:oauth:grant-type:device_code',
                    deviceCode=device_code,
                )
            except (sso_oidc.exceptions.ExpiredTokenException,
                    sso_oidc.exceptions.InvalidGrantException):
                print("Error: Expired token")
                raise SystemExit(1)
            except sso_oidc.exceptions.AuthorizationPendingException:
                print('.', end='', flush=True)
            except sso_oidc.exceptions.SlowDownException:
                # Server asked us to back off: widen the interval for this
                # and all subsequent polls.
                interval += 5
            # Both "pending" and "slow down" must wait before polling again.
            sleep(interval)
    finally:
        # Terminate the single line of progress dots exactly once.
        print()
# Obtain the SSO access token: `cache.get` presumably returns a stored response
# or calls create_token to poll for a new one. Every later SSO call uses it.
# NOTE(review): the token has a limited lifetime (`expiresIn` in the response);
# confirm that `cache` does not keep serving it after it expires.
token_response = cache.get(
    'token_response',
    create_token,
    force=FORCE,
)
access_token = token_response['accessToken']
def get_session(account_id, role_name, region=None):
    """Build a boto3 Session acting as *role_name* inside *account_id*.

    The module-level SSO ``access_token`` is exchanged for temporary role
    credentials via ``sso.get_role_credentials``.

    Args:
        account_id: AWS account to assume the role in.
        role_name: SSO role (permission set) name.
        region: optional region for the returned session.

    Returns:
        A ``boto3.Session`` configured with the temporary credentials.
    """
    reply = sso.get_role_credentials(
        accessToken=access_token,
        accountId=account_id,
        roleName=role_name,
    )
    role_creds = reply['roleCredentials']
    # boto3.Session keyword -> field name in the SSO roleCredentials payload.
    session_fields = {
        'aws_access_key_id': 'accessKeyId',
        'aws_secret_access_key': 'secretAccessKey',
        'aws_session_token': 'sessionToken',
    }
    credential_kwargs = {
        kwarg: role_creds[field] for kwarg, field in session_fields.items()
    }
    return boto3.Session(region_name=region, **credential_kwargs)
def get_aliases(account_id):
    """Return the IAM account aliases of *account_id* (as a list).

    Credentials come from ``get_session`` using the module-level ROLE_NAME.
    """
    role_session = get_session(account_id=account_id, role_name=ROLE_NAME)
    iam_client = role_session.client('iam')
    return list_all(iam_client, 'list_account_aliases', key='AccountAliases')
# List every account returned by Organizations, sorted by name, together with
# the IAM aliases read from inside that account.
s = get_session(account_id=ACCOUNT_ID, role_name=ROLE_NAME)
for account in list_all(s.client('organizations'), 'list_accounts',
                        key='Accounts', sort=lambda acct: acct['Name']):
    aliases = get_aliases(account_id=account['Id'])
    # Account name left-aligned in a 40-character column after one space.
    print(account['Id'], f" {account['Name']:<40}", *aliases)
# if 0:
#     r = s.client('support-app').get_account_alias()
import json
import os
# Command used by dump() to pretty-print JSON. Defaults to `yq -P` (YAML-style
# output); set the JQ environment variable (e.g. JQ=jq) to use another tool
# without editing this file.
jq = os.environ.get('JQ', 'yq -P')
def dump(*args):
    """Print a header line, then pretty-print a value via the ``jq`` command.

    All positional arguments except the last are printed as a header; the
    last one is serialized to JSON (non-JSON types via ``str``) and fed on
    stdin to the formatter named by the module-level ``jq`` string (split
    with ``shlex``, no shell involved, so quotes/backslashes in the data are
    safe). If the formatter is not installed, the raw JSON is printed instead.

    For dicts, the bulky ``ResponseMetadata`` entry is replaced by a single
    ``ResponseMetadata.HTTPStatusCode`` key in the printed output; the
    caller's dict itself is not modified.

    Raises:
        ValueError: if called with no arguments.
    """
    import shlex
    import subprocess
    import sys

    *args, value = args
    if isinstance(value, dict):
        value = dict(value)  # shallow copy: never mutate the caller's response
        meta = value.pop('ResponseMetadata', {})
        if meta:
            value['ResponseMetadata.HTTPStatusCode'] = meta.get('HTTPStatusCode')
    print()
    print(*args)
    text = json.dumps(value, default=str)
    # Flush our buffered output so it appears before the child's output.
    sys.stdout.flush()
    try:
        subprocess.run(shlex.split(jq), input=text, text=True, check=False)
    except FileNotFoundError:
        # Formatter binary missing: still show the data.
        print(text)
def list_all(client, cmd, **kwargs):
    """Collect every item from a paginated boto3 operation into one list.

    Args:
        client: boto3 client exposing ``get_paginator``.
        cmd: paginated operation name, e.g. ``'list_accounts'``.
        key: (keyword) page field that holds the items; when omitted, each
            page itself is iterated (for a dict page this yields its keys).
        sort: (keyword) optional sort-key callable for the final list.
        **kwargs: remaining arguments forwarded to ``paginate``.

    Returns:
        List of collected items, sorted when ``sort`` is given.
    """
    item_field = kwargs.pop('key', None)
    order_by = kwargs.pop('sort', None)
    collected = []
    for page in client.get_paginator(cmd).paginate(**kwargs):
        collected.extend(page[item_field] if item_field else page)
    if not order_by:
        return collected
    return sorted(collected, key=order_by)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment