Skip to content

Instantly share code, notes, and snippets.

@kevlarr
Created February 13, 2020 19:56
Show Gist options
  • Save kevlarr/937b12b02295b631010f9401b8c2160c to your computer and use it in GitHub Desktop.
Save kevlarr/937b12b02295b631010f9401b8c2160c to your computer and use it in GitHub Desktop.
"""
See:
https://aws.amazon.com/premiumsupport/knowledge-center/decode-verify-cognito-json-token/
https://medium.com/datadriveninvestor/jwt-authentication-with-fastapi-and-aws-cognito-1333f7f2729e
"""
from datetime import datetime
import logging
from os import environ
import re
from typing import Any, Dict, List, Optional
from fastapi import HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import jwt, jwk, JWTError
from jose.utils import base64url_decode
from mypy_extensions import TypedDict
from pydantic import BaseModel
import requests
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED
LOG = logging.getLogger(__name__)

# Cognito user-pool coordinates; each is None when the variable is unset.
AWS_REGION = environ.get("AWS_REGION")
USER_POOL_ID = environ.get("USER_POOL_ID")

# Canonical 401 response for rejected credentials. The WWW-Authenticate header
# tells the client which auth scheme (Bearer) the endpoint expects.
UNAUTHORIZED = HTTPException(
    headers={"WWW-Authenticate": "Bearer"},
    status_code=HTTP_401_UNAUTHORIZED,
    detail="Could not validate credentials",
)
# Key ids are the plain strings found in a key's (and a token header's) "kid".
KeyId = str


class JWKey(TypedDict):
    """One public signing key entry from the user pool's JWKS document."""

    # Kept as a TypedDict (a plain dict at runtime) rather than a model,
    # because each entry is handed to JOSE as-is when verifying tokens.
    alg: str  # Algorithm
    e: str
    kid: KeyId
    kty: str
    n: str
    use: str
class JWKeys(BaseModel):
    """Parsed body of the user pool's ``.well-known/jwks.json`` document."""

    # Entries stay plain JWKey dicts so they can be passed straight to JOSE.
    keys: List[JWKey]
def _fetch_jwks(timeout: float = 10.0) -> JWKeys:
    """Download and parse the user pool's public signing keys.

    Args:
        timeout: Seconds to wait for the JWKS endpoint before giving up.

    Returns:
        The parsed JWKS document.

    Raises:
        requests.RequestException: on connection failure, timeout, or a
            non-2xx HTTP status.
    """
    response = requests.get(
        f"https://cognito-idp.{AWS_REGION}.amazonaws.com/"
        f"{USER_POOL_ID}/.well-known/jwks.json",
        # Without a timeout a stalled connection would block the import forever.
        timeout=timeout,
    )
    # Fail on an HTTP error here, rather than letting the error body surface
    # later as a confusing "keys: field required" validation error.
    response.raise_for_status()
    return JWKeys.parse_obj(response.json())


# Need to request public cognito keys (done once, at import time).
__JWKS: JWKeys = _fetch_jwks()
# kid -> key lookup, used to pick the key named in a token's header.
# (Loop variable is not called `jwk`, to avoid confusion with the jose module.)
KEY_MAP: Dict[KeyId, JWKey] = {key["kid"]: key for key in __JWKS.keys}
# Payload claims of a Cognito-issued JWT.
# The TypedDict name must match the variable name (mypy rejects a mismatch).
# total=False: not every token carries every claim; for example, Cognito omits
# "cognito:groups" for users who belong to no group.
JWtClaims = TypedDict("JWtClaims", {
    "aud": str,
    "auth_time": int,  # epoch seconds
    "cognito:groups": List[str],  # group names; read as a list (indexed with [-1])
    "cognito:username": str,
    "email": str,
    "email_verified": bool,
    "event_id": str,
    "exp": int,  # epoch seconds; compared numerically against the current time
    "iat": int,  # epoch seconds
    "iss": str,
    "sub": str,
    "token_use": str,
}, total=False)
class JWTCredentials(BaseModel):
    """A bearer token split into its (not yet verified) parts."""

    # The raw compact-serialized token as received.
    jwt_token: str
    # Decoded JOSE header (e.g. "kid", "alg"), read without verification.
    header: Dict[str, str]
    # Decoded payload claims, read without verification.
    claims: Dict[str, Any]
    # Base64url signature segment (text after the last ".").
    signature: str
    # Signed portion: everything before the last ".".
    message: str
class JWTBearer(HTTPBearer):
    """Custom HTTPBearer that can be used as a FastAPI dependency.

    On top of HTTPBearer's header parsing, it checks that the bearer token is
    an unexpired JWT signed by one of the user pool keys in ``KEY_MAP``.
    Any failure is reported as a 401 ``HTTPException``.

    NOTE(review): ``iss``, ``aud``/``client_id`` and ``token_use`` are not
    checked; the AWS verification guide recommends checking them too.
    """

    async def __call__(self, req: Request) -> Dict:
        """Authenticate *req*'s bearer token.

        Args:
            req: The incoming request; HTTPBearer extracts the credentials.

        Returns:
            ``{"username": <str>, "groups": <last group name, or None>}``.

        Raises:
            HTTPException: 401 if credentials are absent or not Bearer, the
                token is malformed, expired, or fails signature verification,
                or it carries no username claim. HTTPBearer itself may reject
                a missing header earlier, depending on its ``auto_error``.
        """
        http_creds: Optional[HTTPAuthorizationCredentials] = await super().__call__(req)
        # Auth schemes are case-insensitive, and the parent accepts any case,
        # so "bearer <token>" must not be rejected here after passing there.
        if http_creds is None or http_creds.scheme.lower() != "bearer":
            raise self._unauthorized()

        jwt_creds = self._parse_credentials(http_creds.credentials)
        # Explicit checks rather than `assert`: asserts are stripped under
        # `python -O`, which would silently disable all token verification.
        if jwt_creds is None or not self.token_valid(jwt_creds):
            raise self._unauthorized()

        claims = jwt_creds.claims
        # Cognito ID tokens name the user "cognito:username"; access tokens
        # use "username" (see the AWS page linked in the module docstring).
        username = claims.get("cognito:username", claims.get("username"))
        if username is None:
            raise self._unauthorized()
        # A user in no group has no "cognito:groups" claim; report None rather
        # than failing with a KeyError/IndexError (HTTP 500).
        # NOTE(review): only the LAST group is returned despite the key name.
        groups = claims.get("cognito:groups") or [None]
        return {
            "username": username,
            "groups": groups[-1],
        }

    def token_valid(self, creds: JWTCredentials) -> bool:
        """Checks validity of token signature and expiry.

        Returns:
            True only if ``exp`` lies in the future and the signature verifies
            against the ``KEY_MAP`` key named by the header's ``kid``. A
            missing field, unknown ``kid``, non-numeric ``exp``, or malformed
            key/signature encoding yields False instead of raising.
        """
        # Imported here to keep this fix self-contained; JWKError (raised by
        # jwk.construct) derives from JOSEError, not from JWTError.
        from jose.exceptions import JOSEError

        try:
            expires_at = creds.claims["exp"]
            # Written as `not (exp > now)` so odd values (e.g. NaN) are rejected.
            if not expires_at > datetime.now().timestamp():
                return False
            public_key = KEY_MAP[creds.header["kid"]]
            key = jwk.construct(public_key)
            # base64url_decode raises ValueError (binascii.Error) on bad padding.
            signature = base64url_decode(creds.signature.encode())
            return bool(key.verify(creds.message.encode(), signature))
        except (KeyError, TypeError, ValueError, JOSEError):
            return False

    @staticmethod
    def _parse_credentials(token: str) -> Optional[JWTCredentials]:
        """Split *token* into header, claims and signature without verifying it.

        Returns None when *token* is not a well-formed JWT.
        """
        # rpartition never raises: a token with no "." yields an empty
        # separator, instead of the unpacking ValueError that rsplit caused.
        message, sep, signature = token.rpartition(".")
        if not sep:
            return None
        try:
            return JWTCredentials(
                jwt_token=token,
                header=jwt.get_unverified_header(token),
                claims=jwt.get_unverified_claims(token),
                signature=signature,
                message=message,
            )
        except (JWTError, ValueError):
            # ValueError also covers pydantic's ValidationError (e.g. a header
            # with non-string values).
            return None

    @staticmethod
    def _unauthorized() -> HTTPException:
        """Build a fresh 401 matching the module-level UNAUTHORIZED template."""
        # A new instance per failure: re-raising one shared exception object
        # keeps extending its __traceback__/__context__ across requests.
        return HTTPException(
            status_code=UNAUTHORIZED.status_code,
            detail=UNAUTHORIZED.detail,
            headers=UNAUTHORIZED.headers,
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment