Python 3.6 code for JWT authentication with an AWS API Gateway Lambda authorizer, plus Lambda type hints.

AWS Lambda JWT authentication

AWS API Gateway can authenticate a request before it reaches the endpoint by passing the request's token (the authorizationToken) to a Lambda function, known as a Lambda authorizer. Using this feature simplifies endpoint security and removes duplicated authentication code from every endpoint. However, I found the AWS examples excessively complicated for what should be a very simple task.
So here's my example.

The main concern is that API Gateway expects a very specific response from the authorizer Lambda. If the response object is not exactly as expected, API Gateway returns a 500 error with x-amzn-ErrorType: AuthorizerConfigurationException in the response headers.
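The contract described above can be sketched as follows, assuming a TOKEN authorizer. API Gateway sends an event with type, authorizationToken and methodArn, and the authorizer must answer with a principalId and a policyDocument. The ARN values are placeholders, and check_authorizer_response is a hypothetical helper written only to make the expected shape explicit; it is not part of any AWS SDK, and it only accepts the single string forms of Action and Resource used in this gist, so it is stricter than API Gateway itself.

```python
from typing import Any, Dict

# Example TOKEN authorizer event; account, API id and path are placeholders.
SAMPLE_EVENT: Dict[str, Any] = {
    'type': 'TOKEN',
    'authorizationToken': 'Bearer <jwt>',
    'methodArn': 'arn:aws:execute-api:us-east-1:123456789012:abcdef1234/prod/GET/items',
}


def check_authorizer_response(response: Dict[str, Any]) -> bool:
    """Hypothetical helper: True if the response has the shape used in this gist."""
    if 'principalId' not in response:
        return False
    policy = response.get('policyDocument')
    if not isinstance(policy, dict) or policy.get('Version') != '2012-10-17':
        return False
    statements = policy.get('Statement')
    if not isinstance(statements, list) or not statements:
        return False
    for statement in statements:
        if not isinstance(statement, dict):
            return False
        if statement.get('Action') != 'execute-api:Invoke':
            return False
        if statement.get('Effect') not in ('Allow', 'Deny'):
            return False
        if not statement.get('Resource'):
            return False
    return True


good = {
    'principalId': 'user-42',
    'policyDocument': {
        'Version': '2012-10-17',
        'Statement': [{
            'Action': 'execute-api:Invoke',
            'Effect': 'Allow',
            'Resource': SAMPLE_EVENT['methodArn'],
        }],
    },
}
# An ad-hoc error dict like this is the kind of response that leads to the 500 above.
bad = {'error': 'KeyError:SECRET_KEY'}

print(check_authorizer_response(good))  # True
print(check_authorizer_response(bad))   # False
```

The point of the check is that every key is required: a response with only a principalId, or with an error message instead of a policy, does not satisfy it.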

Usage

I personally use https://zappa.io to publish my Lambdas, but I'll include an image of the API Gateway config.

import logging
import os
from typing import Optional, Union

import jwt  # PyJWT

from lambda_types import LambdaContext, LambdaDict

logger = logging.getLogger()
logger.setLevel(logging.INFO)


def generate_policy(principal_id: Union[int, str, None], effect: str, method_arn: str) -> dict:
    """Return a valid AWS API Gateway authorizer policy response."""
    auth_response = {'principalId': principal_id}
    if effect and method_arn:
        policy_document = {
            'Version': '2012-10-17',
            'Statement': [
                {
                    'Sid': 'FirstStatement',
                    'Action': 'execute-api:Invoke',
                    'Effect': effect,
                    'Resource': method_arn,
                }
            ],
        }
        auth_response['policyDocument'] = policy_document
    return auth_response


def decode_auth_token(auth_token: str) -> Optional[dict]:
    """Decode and verify the auth token, returning its payload or None."""
    # Remove the "Bearer " prefix from the token string, if present.
    if auth_token.startswith('Bearer '):
        auth_token = auth_token[len('Bearer '):]
    try:
        # Decode using the $SECRET_KEY environment variable; raises KeyError if it is not set.
        # PyJWT 2.x requires an explicit algorithms list; HS256 is assumed here.
        return jwt.decode(auth_token, os.environ['SECRET_KEY'], algorithms=['HS256'])
    except jwt.ExpiredSignatureError:
        logger.info('Signature expired. Please log in again.')
        return None
    except jwt.InvalidTokenError:
        logger.info('Invalid token. Please log in again.')
        return None


def lambda_handler(event: LambdaDict, context: LambdaContext) -> dict:
    method_arn = event.get('methodArn')
    try:
        auth_token = event.get('authorizationToken')
        if auth_token and method_arn:
            # Verify the JWT.
            user_details = decode_auth_token(auth_token)
            if user_details:
                # The JWT is valid and not expired, so return an Allow policy.
                return generate_policy(user_details.get('id'), 'Allow', method_arn)
    except Exception:
        # Returning anything other than a policy makes API Gateway respond with a 500
        # (AuthorizerConfigurationException), so log the error and fall through to Deny.
        logger.exception('Authorizer failed')
    return generate_policy(None, 'Deny', method_arn)
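decode_auth_token only verifies tokens; something else, typically your login endpoint, has to issue them. To test the authorizer you can mint a token signed with the same SECRET_KEY. The sketch below builds and verifies an HS256 JWT using only the standard library, assuming the authorizer decodes with HS256 and reads the user id from an 'id' claim as lambda_handler does. make_hs256_token and verify_hs256_token are hypothetical names; in real code you would normally call PyJWT's jwt.encode(payload, key, algorithm='HS256') rather than hand-rolling this, but a token built this way follows the standard JWS compact format and should be accepted by the jwt.decode call above.

```python
import base64
import hashlib
import hmac
import json
import time
from typing import Any, Dict, Optional


def _b64url_encode(data: bytes) -> str:
    # JWTs use unpadded base64url encoding.
    return base64.urlsafe_b64encode(data).rstrip(b'=').decode('ascii')


def _b64url_decode(segment: str) -> bytes:
    # Restore the padding that _b64url_encode stripped.
    return base64.urlsafe_b64decode(segment + '=' * (-len(segment) % 4))


def make_hs256_token(payload: Dict[str, Any], secret: str) -> str:
    """Build a compact HS256 JWT: header.payload.signature."""
    header = {'alg': 'HS256', 'typ': 'JWT'}
    signing_input = '.'.join(
        _b64url_encode(json.dumps(part, separators=(',', ':')).encode('utf-8'))
        for part in (header, payload)
    )
    signature = hmac.new(secret.encode('utf-8'), signing_input.encode('ascii'), hashlib.sha256).digest()
    return signing_input + '.' + _b64url_encode(signature)


def verify_hs256_token(token: str, secret: str) -> Optional[Dict[str, Any]]:
    """Return the payload if the HS256 signature is valid and 'exp' has not passed."""
    try:
        header_b64, payload_b64, signature_b64 = token.split('.')
        signing_input = (header_b64 + '.' + payload_b64).encode('ascii')
        expected = hmac.new(secret.encode('utf-8'), signing_input, hashlib.sha256).digest()
        # Constant-time comparison avoids leaking timing information.
        if not hmac.compare_digest(expected, _b64url_decode(signature_b64)):
            return None
        payload = json.loads(_b64url_decode(payload_b64).decode('utf-8'))
    except ValueError:
        # Wrong segment count, bad base64, non-ASCII input or invalid JSON.
        return None
    if not isinstance(payload, dict):
        return None
    if 'exp' in payload and payload['exp'] < time.time():
        return None
    return payload


secret = 'change-me'  # stands in for os.environ['SECRET_KEY']
token = make_hs256_token({'id': 42, 'exp': int(time.time()) + 3600}, secret)
print(verify_hs256_token(token, secret)['id'])    # 42
print(verify_hs256_token(token, 'wrong-secret'))  # None
```

Sending this token as "Bearer <token>" in the header configured on the authorizer should produce an Allow policy with principalId 42, provided the Lambda's SECRET_KEY matches.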
# lambda_types.py, taken from https://gist.github.com/alexcasalboni/a545b68ee164b165a74a20a5fee9d133
from typing import Any, Dict

LambdaDict = Dict[str, Any]


class LambdaCognitoIdentity(object):
    cognito_identity_id: str
    cognito_identity_pool_id: str


class LambdaClientContextMobileClient(object):
    installation_id: str
    app_title: str
    app_version_name: str
    app_version_code: str
    app_package_name: str


class LambdaClientContext(object):
    client: LambdaClientContextMobileClient
    custom: LambdaDict
    env: LambdaDict


class LambdaContext(object):
    function_name: str
    function_version: str
    invoked_function_arn: str
    memory_limit_in_mb: int
    aws_request_id: str
    log_group_name: str
    log_stream_name: str
    identity: LambdaCognitoIdentity
    client_context: LambdaClientContext

    @staticmethod
    def get_remaining_time_in_millis() -> int:
        return 0
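These classes only carry annotations, so they serve as type hints; your code never instantiates them, because Lambda passes its own context object at runtime. To call a handler locally, for example in a unit test, a stand-in object with the same attribute names is usually enough. A minimal sketch, assuming a duck-typed types.SimpleNamespace is acceptable wherever LambdaContext is expected; make_fake_context, echo_handler and the attribute values are hypothetical.

```python
from types import SimpleNamespace
from typing import Any, Dict


def make_fake_context(function_name: str = 'jwt-authorizer') -> SimpleNamespace:
    """Hypothetical local stand-in for the runtime Lambda context object."""
    return SimpleNamespace(
        function_name=function_name,
        function_version='$LATEST',
        invoked_function_arn='arn:aws:lambda:us-east-1:123456789012:function:' + function_name,
        memory_limit_in_mb=128,
        aws_request_id='local-test',
        log_group_name='/aws/lambda/' + function_name,
        log_stream_name='local',
        identity=None,
        client_context=None,
        # Pretend there are always 30 seconds of execution time left.
        get_remaining_time_in_millis=lambda: 30000,
    )


def echo_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
    # Trivial handler used only to show the fake context being passed in.
    return {'function': context.function_name, 'remaining_ms': context.get_remaining_time_in_millis()}


print(echo_handler({}, make_fake_context()))
# {'function': 'jwt-authorizer', 'remaining_ms': 30000}
```

In a test, the same make_fake_context() result can be passed as the second argument to lambda_handler alongside a sample TOKEN event.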

men6288 commented Jan 29, 2022

Hi @bendog

This example is really helpful. I get the idea of decoding the tokens, but how are they generated? I am not too well versed in OAuth and I am currently trying to set this up on our API Gateway, so anything helps.
