Skip to content

Instantly share code, notes, and snippets.

@ergo70 ergo70/
Last active Mar 15, 2020

What would you like to do?
pam-python PAM module for AWS Cognito
import boto3
import jwt
from json import dumps
from requests import get
from botocore import UNSIGNED
from botocore.config import Config
from warrant.aws_srp import AWSSRP
from cryptography.hazmat.primitives import serialization
def _get_credentials(pamh):
    """Collect the user name and password for the current PAM transaction.

    If the handle has no auth token yet, the password is requested through a
    PAM conversation with echo disabled and stored back on ``pamh.authtok``.

    Args:
        pamh: the pam-python handle passed to the ``pam_sm_*`` entry points.

    Returns:
        A ``(user, password)`` tuple. If PAM raises ``pamh.exception`` while
        fetching the user or running the conversation, ``(None, None)`` is
        returned so callers can always unpack two values.
    """
    try:
        user = pamh.get_user(None)
        if pamh.authtok is None:
            # NOTE(review): _PAM_DUMMY_PROMPT is a module-level constant that is
            # not visible in this section of the file -- confirm it is defined.
            message = pamh.Message(pamh.PAM_PROMPT_ECHO_OFF, _PAM_DUMMY_PROMPT)
            pamh.authtok = pamh.conversation(message).resp
    except pamh.exception:
        return None, None
    return user, pamh.authtok
def _get_public_keys_for_cognito_pool(region, pool_id, timeout=10):
    """Fetch the JSON Web Key Set (JWKS) published for a Cognito user pool.

    Args:
        region: AWS region of the user pool.
        pool_id: Cognito user pool id.
        timeout: seconds to wait for the HTTP request. A bounded timeout keeps
            an unreachable endpoint from blocking the login indefinitely.

    Returns:
        The list stored under the JWKS ``'keys'`` member, or an empty list if
        the document has no such member.

    Raises:
        requests.HTTPError: if the endpoint answers with an error status.
        requests.RequestException: on connection failure or timeout.
    """
    url = 'https://cognito-idp.{}.amazonaws.com/{}/.well-known/jwks.json'.format(
        region, pool_id)
    response = get(url, timeout=timeout)
    response.raise_for_status()
    return response.json().get('keys', [])
def _jwks2PEM(jwks_key):
    """Return the PEM (SubjectPublicKeyInfo) encoding of one RSA key from a JWKS.

    Args:
        jwks_key: a single entry of the JWKS ``keys`` list, as a dict.

    Returns:
        The public key serialized as PEM bytes.
    """
    pem_encoding = serialization.Encoding.PEM
    spki_format = serialization.PublicFormat.SubjectPublicKeyInfo
    # from_jwk is handed the key re-serialized as a JSON string.
    rsa_key = jwt.algorithms.RSAAlgorithm.from_jwk(dumps(jwks_key))
    return rsa_key.public_bytes(encoding=pem_encoding, format=spki_format)
def _verify_claim(region, token, pool_id, client_id=None):
    """Verify a Cognito-issued JWT and return its decoded claims.

    The signing key is selected from the pool's JWKS by the token header's
    ``kid``. The signature, expiry, issuer (the pool URL) and, when
    ``client_id`` is given, the audience are checked.

    Args:
        region: AWS region of the user pool.
        token: the encoded JWT, or ``None``.
        pool_id: Cognito user pool id.
        client_id: expected ``aud`` claim; pass ``None`` for tokens that carry
            no audience.

    Returns:
        The claims dict, or ``None`` if the token is missing or malformed, has
        no matching key in the JWKS, or fails verification.
    """
    if not token:
        return None
    # NOTE(review): _PUBKEYS is a module-level cache that is not visible in this
    # section and is never written here, so keys may be fetched on every call.
    pubkeys = _PUBKEYS
    if not pubkeys:
        pubkeys = _get_public_keys_for_cognito_pool(region, pool_id)
    try:
        kid = jwt.get_unverified_header(token).get('kid')
    except jwt.InvalidTokenError:
        return None
    if kid is None:
        return None
    issuer = 'https://cognito-idp.{}.amazonaws.com/{}'.format(region, pool_id)
    for key in pubkeys or ():
        if key.get('kid') != kid:
            continue
        try:
            # 'algorithms' (plural, a whitelist) is the keyword PyJWT enforces;
            # the singular 'algorithm' was silently ignored.
            return jwt.decode(token, _jwks2PEM(key),
                              algorithms=[key.get('alg') or 'RS256'],
                              audience=client_id, issuer=issuer)
        except jwt.InvalidTokenError:
            # Bad signature, expired token, or wrong audience/issuer.
            return None
    return None
def _cognito_auth_handler(region, pool_id, client_id, user, password):
    """Authenticate ``user``/``password`` against a Cognito user pool via SRP.

    Args:
        region: AWS region of the user pool.
        pool_id: Cognito user pool id.
        client_id: Cognito app client id.
        user: user name obtained from PAM.
        password: password obtained from PAM.

    Returns:
        True only if SRP authentication succeeds and both the returned ID
        token and access token verify. Otherwise False, which covers missing
        or empty arguments, rejected credentials, unsupported challenges and
        missing tokens.
    """
    # Empty values can never authenticate; reject them before any network call.
    if not all((region, pool_id, client_id, user, password)):
        return False
    # Imported after the guard so the fast-reject path needs no AWS libraries.
    from botocore.exceptions import ClientError
    from warrant.exceptions import ForceChangePasswordException

    client = boto3.client('cognito-idp', region_name=region,
                          config=Config(signature_version=UNSIGNED))
    aws = AWSSRP(username=user, password=password, pool_id=pool_id,
                 client_id=client_id, client=client)
    try:
        tokens = aws.authenticate_user()
    except (ClientError, ForceChangePasswordException, NotImplementedError):
        # Rejected credentials (e.g. Cognito's NotAuthorizedException), a
        # required password change, or an SRP challenge warrant does not
        # implement: a failed login, not a module crash.
        return False
    result = (tokens or {}).get('AuthenticationResult')
    if not result:
        return False
    id_claim = _verify_claim(region, result.get('IdToken'), pool_id, client_id)
    # Access tokens carry no 'aud' claim, so no audience is checked.
    access_claim = _verify_claim(region, result.get('AccessToken'), pool_id)
    return id_claim is not None and access_claim is not None
def pam_sm_authenticate(pamh, flags, argv):
    """PAM ``auth`` entry point: authenticate the user against AWS Cognito.

    The module must be configured with exactly three arguments after the
    module path: the region, the user pool id and the app client id.

    Returns:
        ``pamh.PAM_SUCCESS`` on successful authentication, otherwise
        ``pamh.PAM_AUTH_ERR``. That includes a wrong argument count and a
        failure to obtain credentials.
    """
    # argv[0] (unused here) is the module path; three settings follow it.
    if len(argv) != 4:
        return pamh.PAM_AUTH_ERR
    # _get_credentials may signal failure with a bare None instead of a pair;
    # unpacking that directly would raise TypeError.
    credentials = _get_credentials(pamh)
    if not credentials:
        return pamh.PAM_AUTH_ERR
    user, password = credentials
    region, pool_id, client_id = argv[1:4]
    if _cognito_auth_handler(region, pool_id, client_id, user, password):
        return pamh.PAM_SUCCESS
    return pamh.PAM_AUTH_ERR
# Authorize user account. Type: account
def pam_sm_acct_mgmt(pamh, flags, argv):
    """PAM ``account`` entry point: accept every account unconditionally.

    No account checks are performed here; the only gate this module applies
    is authentication in ``pam_sm_authenticate``.
    """
    return pamh.PAM_SUCCESS


# Set credentials. Type: auth
def pam_sm_setcred(pamh, flags, argv):
    """PAM ``pam_setcred`` hook: no credentials to establish, report success.

    pam-python answers a PAM call whose handler is not defined in the module
    with PAM_SYMBOL_ERR. Applications may call ``pam_setcred`` after a
    successful ``auth`` stack, so a missing handler can fail the login.
    """
    return pamh.PAM_SUCCESS
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.