Skip to content

Instantly share code, notes, and snippets.

@benc-uk
Last active February 8, 2024 13:38
Show Gist options
  • Save benc-uk/8e576cf72b2361782060f20917f2e280 to your computer and use it in GitHub Desktop.
Save benc-uk/8e576cf72b2361782060f20917f2e280 to your computer and use it in GitHub Desktop.
Python
from fastapi import FastAPI, Request, status
from fastapi.responses import PlainTextResponse
import jwt
import logging
# Module-level logger; auth_middleware uses it to record token-validation failures.
logger = logging.getLogger(__name__)
# The FastAPI application that the auth middleware and the "/" route below are registered on.
app = FastAPI()
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    """Reject every request that does not carry a valid bearer token.

    The request must have an ``Authorization: Bearer <token>`` header. The token
    is checked with ``validate_token``. If the header is missing or malformed,
    or validation fails, a plain-text 401 is returned and the request never
    reaches a route handler. Otherwise the request is forwarded via ``call_next``.

    Args:
        request: The incoming HTTP request.
        call_next: Callable supplied by the framework that passes the request
            to the next middleware / route handler and returns its response.

    Returns:
        The downstream response, or a 401 ``PlainTextResponse``.
    """
    # Response for unauthorized requests. RFC 6750 asks a bearer-protected
    # resource to advertise the expected scheme in WWW-Authenticate.
    resp401 = PlainTextResponse(
        "Unauthorized",
        status_code=status.HTTP_401_UNAUTHORIZED,
        headers={"WWW-Authenticate": "Bearer"},
    )

    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return resp401

    # Split "<scheme> <credentials>" at the first space. The scheme name is
    # case-insensitive (RFC 7235). Any other shape (e.g. "Basic ...", or a
    # bare "Bearer") is rejected with a 401 rather than raising.
    scheme, _, token = auth_header.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        return resp401

    # Guard only the validation step. Exceptions raised later by route
    # handlers (via call_next) must surface as real errors, not as a
    # misleading 401 "bad token".
    # NOTE(review): validate_token is synchronous and may fetch the JWKS over
    # the network, which blocks the event loop while it runs. Consider
    # fastapi.concurrency.run_in_threadpool if that latency matters.
    try:
        decoded_token = validate_token(token)
    except Exception as e:
        logger.error("ERROR: Problem validating token: %s", e)
        return resp401

    # A middleware must always return a Response, so a falsy result is
    # treated as unauthorized instead of falling through to None.
    if not decoded_token:
        return resp401

    # Here you can do some authorization logic like checking scopes, roles, etc.
    # But we don't, we just chain the request to the next middleware
    return await call_next(request)
# JWKS endpoint that publishes the Microsoft identity platform signing keys
# for the multi-tenant "common" authority.
# Magic URL you might want to put in a config file
JWKS_URI = "https://login.microsoftonline.com/common/discovery/keys"

# For your API, this will be the Application ID (GUID) of the client you have registered
API_AUDIENCE = "b79fbf4d-3ef9-4689-8143-76b194e85509"

# One shared client for the whole process. The cache_jwk_set/lifespan settings
# only help if the same PyJWKClient instance is reused across calls. A client
# built per call starts with an empty cache and re-fetches the key set every time.
_jwks_client = jwt.PyJWKClient(
    uri=JWKS_URI,
    cache_jwk_set=True,
    lifespan=600,
)


def validate_token(token: str) -> dict:
    """Verify a JWT against the Azure AD signing keys and return its claims.

    The signing key is selected from the shared JWKS client using the token's
    header. The signature is then verified with RS256 and the audience must
    equal ``API_AUDIENCE``. Tokens without an ``exp`` claim are rejected.

    Args:
        token: The raw encoded JWT (without the "Bearer " prefix).

    Returns:
        The decoded claims as a dict.

    Raises:
        Exceptions from PyJWT (typically ``jwt.PyJWTError`` subclasses) when
        the key cannot be fetched or matched, or when the token is invalid,
        expired, or issued for a different audience.
    """
    # NOTE(review): the issuer is not validated. With keys from the "common"
    # authority, a correctly-signed token from any tenant with this audience may
    # be accepted. For a single-tenant API, pass issuer=... to jwt.decode.
    signing_key = _jwks_client.get_signing_key_from_jwt(token)
    return jwt.decode(
        token,
        signing_key.key,
        # This is the algorithm that Azure AD uses and lots of other OIDC providers
        algorithms=["RS256"],
        audience=API_AUDIENCE,
        # PyJWT only checks "exp" when the claim is present; require it so a
        # token can never be valid forever.
        options={"require": ["exp"]},
    )
# Minimal route used to demonstrate the middleware; it is served behind auth_middleware.
@app.get("/")
def read_root():
    """Return a fixed greeting payload for GET /."""
    payload = dict(Hello="World")
    return payload
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment