Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
"""
See:
https://aws.amazon.com/premiumsupport/knowledge-center/decode-verify-cognito-json-token/
https://medium.com/datadriveninvestor/jwt-authentication-with-fastapi-and-aws-cognito-1333f7f2729e
"""
from datetime import datetime
import logging
from os import environ
import re
from typing import Any, Dict, List, Optional
from fastapi import HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import jwt, jwk, JWTError
from jose.utils import base64url_decode
from mypy_extensions import TypedDict
from pydantic import BaseModel
import requests
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED
# Module logger plus the Cognito user-pool coordinates, each read from the
# environment exactly once at import time (None when the variable is unset).
LOG = logging.getLogger(__name__)
AWS_REGION, USER_POOL_ID = (
    environ.get(var_name) for var_name in ("AWS_REGION", "USER_POOL_ID")
)
# Generic 401 error for a bearer token that fails validation. Arguments are,
# in order: status code, detail message, and response headers (the
# WWW-Authenticate header names the scheme the client should use).
UNAUTHORIZED = HTTPException(
    HTTP_401_UNAUTHORIZED,
    "Could not validate credentials",
    {"WWW-Authenticate": "Bearer"},
)
# Alias marking a string as a JSON Web Key id (the "kid" field).
KeyId = str
# One entry of the user pool's JSON Web Key Set (an item of JWKeys.keys).
# Needs to remain a base dict for token parsing w/ JOSE: the entry is handed
# unchanged to ``jose.jwk.construct`` when a token's signature is verified.
JWKey = TypedDict("JWKey", {
    "alg": str,  # Algorithm
    "e": str,  # RSA public exponent (base64url) when "kty" is "RSA"
    "kid": KeyId,  # key id; looked up from a token header's "kid"
    "kty": str,  # key type
    "n": str,  # RSA modulus (base64url) when "kty" is "RSA"
    "use": str,  # intended use of the key
})
class JWKeys(BaseModel):
    """Validated body of the user pool's ``.well-known/jwks.json`` document."""

    keys: List[JWKey]


# Seconds to wait for Cognito. Without a timeout, a stalled connection would
# block module import (and therefore application start-up) indefinitely.
_JWKS_TIMEOUT_SECONDS = 10


def _fetch_jwks(timeout: float = _JWKS_TIMEOUT_SECONDS) -> JWKeys:
    """Download and validate the user pool's public signing keys.

    Args:
        timeout: Seconds to wait for the HTTP request before failing.

    Returns:
        The parsed key set.

    Raises:
        requests.RequestException: on network failure, timeout, or a non-2xx
            response (instead of attempting to parse an error body as keys).
    """
    # NOTE(review): if AWS_REGION / USER_POOL_ID are unset the URL contains
    # "None" and the request fails — confirm both are always configured.
    response = requests.get(
        f"https://cognito-idp.{AWS_REGION}.amazonaws.com/"
        f"{USER_POOL_ID}/.well-known/jwks.json",
        timeout=timeout,
    )
    response.raise_for_status()
    return JWKeys.parse_obj(response.json())


# Need to request public cognito keys (once, at import time)
__JWKS: JWKeys = _fetch_jwks()
# Index keys by id so a token header's "kid" selects its verification key.
# ("key", not "jwk", to avoid shadowing the imported ``jose.jwk`` module.)
KEY_MAP: Dict[KeyId, JWKey] = {key["kid"]: key for key in __JWKS.keys}
# Claims expected in a Cognito-issued token payload. The functional TypedDict
# form is required because several claim names contain ":", which cannot be
# a class-body field name. The first argument must match the variable name
# for type checkers to accept the definition.
JWtClaims = TypedDict("JWtClaims", {
    "aud": str,
    "auth_time": int,  # NumericDate: seconds since the Unix epoch
    "cognito:groups": List[str],  # NOTE(review): absent when the user is in no group
    "cognito:username": str,
    "email": str,
    # NOTE(review): Cognito may emit this as a JSON boolean — confirm the type.
    "email_verified": str,
    "event_id": str,
    "exp": int,  # NumericDate; compared with the current POSIX timestamp
    "iat": int,  # NumericDate
    "iss": str,
    "sub": str,
    "token_use": str,
})
class JWTCredentials(BaseModel):
    """A bearer token split into the parts needed to verify it.

    Built from the raw token *before* any verification, so ``header`` and
    ``claims`` must be treated as untrusted until the signature is checked.
    """

    jwt_token: str  # the raw token exactly as received
    header: Dict[str, str]  # decoded JOSE header (e.g. "kid"), unverified
    claims: Dict[str, Any]  # decoded payload claims, unverified
    signature: str  # base64url text after the token's last "."
    message: str  # signed input: everything before the token's last "."
class JWTBearer(HTTPBearer):
    """Custom HTTPBearer that can be used as a FastAPI dependency.

    Verifies the bearer token's signature against the Cognito public keys in
    ``KEY_MAP`` and checks its expiry and issuer. Every validation failure
    is reported to the client as HTTP 401.
    """

    async def __call__(self, req: Request) -> Dict[str, Any]:
        """Authenticate ``req`` and return the caller's Cognito identity.

        Args:
            req: Incoming request carrying an ``Authorization`` header.

        Returns:
            ``{"username": <cognito:username claim>, "groups": <last entry of
            the cognito:groups claim, or None if the claim is absent/empty>}``.

        Raises:
            HTTPException: 401 when the credentials are missing or not a
                Bearer token, are malformed, fail ``token_valid``, or carry no
                ``cognito:username`` claim. The parent ``HTTPBearer`` may raise
                its own error first (e.g. for a missing header) when
                ``auto_error`` is enabled.
        """
        http_creds: Optional[HTTPAuthorizationCredentials] = await super().__call__(req)
        # Auth schemes are case-insensitive (RFC 7235 §2.1), and HTTPBearer
        # itself accepts e.g. "bearer", so compare case-insensitively here too.
        if http_creds is None or http_creds.scheme.lower() != "bearer":
            raise self._unauthorized()

        token = http_creds.credentials
        try:
            # The signature is the segment after the last "."; everything
            # before it is the signed message.
            message, signature = token.rsplit(".", 1)
            jwt_creds = JWTCredentials(
                jwt_token=token,
                header=jwt.get_unverified_header(token),
                claims=jwt.get_unverified_claims(token),
                signature=signature,
                message=message,
            )
        except (ValueError, JWTError):
            # ValueError: a token with no "." cannot be unpacked, and pydantic
            # validation errors subclass ValueError. Report these as 401, not 500.
            raise self._unauthorized() from None

        # Explicit check rather than ``assert``: assert statements are removed
        # under ``python -O``, which would silently skip verification.
        if not self.token_valid(jwt_creds):
            raise self._unauthorized()

        claims = jwt_creds.claims
        username = claims.get("cognito:username")
        if username is None:
            raise self._unauthorized()
        # Users who belong to no group carry no "cognito:groups" claim.
        groups = claims.get("cognito:groups") or []
        return {
            "username": username,
            "groups": groups[-1] if groups else None,
        }

    def token_valid(self, creds: JWTCredentials) -> bool:
        """Check the token's expiry, issuer and signature.

        Args:
            creds: Parsed but not yet verified token parts.

        Returns:
            True only if the ``exp`` claim lies in the future, the ``iss``
            claim names this module's user pool, the header's ``kid`` is a
            known key in ``KEY_MAP``, and the signature over ``creds.message``
            verifies with that key. Otherwise returns False; missing claims,
            a non-numeric ``exp`` or an undecodable signature do not raise.
        """
        claims = creds.claims
        expected_issuer = (
            f"https://cognito-idp.{AWS_REGION}.amazonaws.com/{USER_POOL_ID}"
        )
        try:
            if claims["exp"] <= datetime.now().timestamp():
                return False  # expired
            if claims["iss"] != expected_issuer:
                return False  # not issued by the configured user pool
            public_key = KEY_MAP[creds.header["kid"]]
            signature = base64url_decode(creds.signature.encode())
        except (KeyError, TypeError, ValueError):
            # KeyError: missing claim or unknown kid; TypeError: non-numeric
            # "exp"; ValueError (binascii.Error): signature is not base64url.
            return False
        key = jwk.construct(public_key)
        return bool(key.verify(creds.message.encode(), signature))

    @staticmethod
    def _unauthorized() -> HTTPException:
        """Build a fresh 401 error with the same fields as ``UNAUTHORIZED``.

        A new instance per failure avoids re-raising one shared exception
        object. Each raise would otherwise extend that object's
        ``__traceback__``, keeping earlier requests' frames alive.
        """
        return HTTPException(
            status_code=UNAUTHORIZED.status_code,
            detail=UNAUTHORIZED.detail,
            headers=UNAUTHORIZED.headers,
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment