Created September 25, 2022 14:07
-
-
Save gessfred/024c4c356b774d9f55a2b29328dedd94 to your computer and use it in GitHub Desktop.
Code sample to do JWT-based auth (Auth0) on cross-origin endpoints with FastAPI
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
from urllib.parse import urlsplit

import requests
from fastapi import (
    FastAPI, Response, Request, Depends, Header, Cookie, HTTPException, status, WebSocket,
)
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, PlainTextResponse
from jose import jwt
# Auth0 tenant domain; used to build the JWKS URL and the expected token issuer.
AUTH0_DOMAIN = 'xxxx.us.auth0.com'
# Accepted JWT signing algorithms. Only RS256: verification keys are taken from
# the tenant's public JWKS (RSA keys), so an HMAC algorithm such as HS256 can
# never verify a legitimate token here and accepting it only widens the
# algorithm-confusion attack surface.
ALGORITHMS = ["RS256"]
# Required "aud" claim of access tokens issued for this API.
API_AUDIENCE = 'https://api.example.com/'
def decode_auth_header(auth: str) -> str:
    """Extract the token from an ``Authorization: Bearer <token>`` header value.

    Args:
        auth: Raw value of the Authorization header.

    Returns:
        The token that follows the ``Bearer`` scheme.

    Raises:
        ValueError: If the value is empty/None, does not consist of exactly two
            whitespace-separated parts, or the scheme is not ``Bearer``
            (compared case-insensitively).
    """
    # Explicit raise rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let malformed headers through.
    parts = auth.split() if auth else []
    if len(parts) != 2:
        raise ValueError("Authorization header must be of the form 'Bearer <token>'")
    scheme, token = parts
    if scheme.lower() != "bearer":
        raise ValueError(f"Unsupported authorization scheme: {scheme!r}")
    return token
@functools.lru_cache(maxsize=1)
def _fetch_jwks() -> dict:
    """Download the tenant's JWKS and index its keys by key id (``kid``).

    Cached so the JWKS endpoint is not hit on every request; call
    ``_fetch_jwks.cache_clear()`` to force a refresh.
    """
    response = requests.get(f"https://{AUTH0_DOMAIN}/.well-known/jwks.json", timeout=10)
    response.raise_for_status()
    return {key["kid"]: key for key in response.json()["keys"]}


def _signing_key(kid: str) -> dict:
    """Return the JWK for ``kid``, refreshing the cached JWKS once if it is unknown."""
    keys = _fetch_jwks()
    if kid not in keys:
        # The signing key may have been rotated after the JWKS was cached.
        _fetch_jwks.cache_clear()
        keys = _fetch_jwks()
    try:
        return keys[kid]
    except KeyError:
        raise jwt.JWTError(f"Unknown signing key id: {kid!r}") from None


def get_token_payload(token: str) -> dict:
    """Verify an Auth0-issued JWT and return its claims.

    The token's signature is checked against the matching key from the
    tenant's JWKS, and its audience and issuer must equal ``API_AUDIENCE`` and
    ``https://{AUTH0_DOMAIN}/``.

    Args:
        token: The encoded JWT (without the ``Bearer`` prefix).

    Returns:
        The decoded claims dictionary.

    Raises:
        jose.jwt.JWTError: (or a subclass such as ``ExpiredSignatureError`` /
            ``JWTClaimsError``) if the header is malformed, the algorithm is
            not allowed, the key id is missing/unknown, or verification fails.
        requests.RequestException: if the JWKS cannot be downloaded.
    """
    unverified_header = jwt.get_unverified_header(token)
    # Reject disallowed algorithms before doing any network I/O.
    if unverified_header.get("alg") not in ALGORITHMS:
        raise jwt.JWTError("Unsupported JWT algorithms")
    kid = unverified_header.get("kid")
    if kid is None:
        raise jwt.JWTError("JWT header has no 'kid'")
    return jwt.decode(
        token,
        _signing_key(kid),
        # Pass the server-side allow-list, never the algorithm named in the
        # (attacker-controlled) token header.
        algorithms=ALGORITHMS,
        audience=API_AUDIENCE,
        issuer=f"https://{AUTH0_DOMAIN}/",
    )
# Domain (and its subdomains) allowed to make credentialed cross-origin calls.
_ALLOWED_ORIGIN_DOMAIN = "example.com"


def _request_origin(request: Request):
    """Return the caller's origin as ``scheme://host[:port]``, or None if unknown.

    Prefers the ``Origin`` header; falls back to the scheme and host of
    ``Referer`` (whose path must not leak into Access-Control-Allow-Origin).
    """
    origin = request.headers.get("origin")
    if origin:
        return origin.rstrip("/")
    referer = request.headers.get("referer")
    if not referer:
        return None
    try:
        parts = urlsplit(referer)
    except ValueError:
        return None
    if not (parts.scheme and parts.netloc):
        return None
    return f"{parts.scheme}://{parts.netloc}"


def _is_allowed_origin(origin) -> bool:
    """Return True if *origin* is http(s) on ``_ALLOWED_ORIGIN_DOMAIN`` or a subdomain.

    Compares the parsed hostname exactly; a substring test would also accept
    hosts such as ``example.com.evil.net`` or ``notexample.com``.
    """
    if not origin:
        return False
    try:
        parts = urlsplit(origin)
    except ValueError:
        return False
    host = (parts.hostname or "").lower()
    return parts.scheme in ("http", "https") and (
        host == _ALLOWED_ORIGIN_DOMAIN or host.endswith("." + _ALLOWED_ORIGIN_DOMAIN)
    )


@app.middleware("http")
async def authorize_request(request: Request, call_next):
    """Answer CORS preflights and require a valid Auth0 bearer token on other requests.

    NOTE(review): ``app`` is not defined in this snippet; it must be a
    ``FastAPI()`` instance created before this decorator runs.

    - OPTIONS: 200 with CORS headers for allowed origins, 401 otherwise.
    - Other methods: 401 if the Authorization header is missing or the token
      fails verification, 403 if the token's audience is not ``API_AUDIENCE``;
      otherwise the request is forwarded. CORS headers are attached only when
      the caller's origin is allowed.
    """
    origin = _request_origin(request)
    cors_headers = {}
    if _is_allowed_origin(origin):
        # Reflect only vetted origins: echoing any Origin together with
        # Allow-Credentials would let arbitrary sites read these responses.
        cors_headers = {
            "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
            "Access-Control-Allow-Credentials": "true",
            "Access-Control-Allow-Origin": origin,
            "Access-Control-Allow-Headers": "Origin, X-Requested-With, Content-Type, Accept, Authorization",
            "Access-Control-Max-Age": "86400",
        }

    if request.method == "OPTIONS":
        if not cors_headers:
            print(origin, "not in allowed origins")
            return PlainTextResponse("CORS error", status_code=401)
        return PlainTextResponse("OK", status_code=200, headers=cors_headers)

    authorization = request.headers.get("authorization")
    if authorization is None:
        # CORS headers included so the browser lets the client read the 401.
        return JSONResponse(
            status_code=401,
            content={"detail": "No credentials found (authorization bearer or cookie)"},
            headers=cors_headers,
        )
    try:
        token = decode_auth_header(authorization)
        # Verification may fetch the JWKS over blocking HTTP; keep it off the event loop.
        payload = await run_in_threadpool(get_token_payload, token)
    except (AssertionError, ValueError, KeyError, jwt.JWTError):
        # Malformed header or invalid/expired/unknown-key token is a client
        # error (401), not a server error (500).
        return JSONResponse(
            status_code=401,
            content={"detail": "Invalid credentials"},
            headers=cors_headers,
        )

    # Defence in depth: "aud" may be a single string or a list of audiences.
    audiences = payload.get("aud", [])
    if isinstance(audiences, str):
        audiences = [audiences]
    if API_AUDIENCE not in audiences:
        return JSONResponse(
            status_code=403,
            content={"detail": "unauthorized"},
            headers=cors_headers,
        )

    response = await call_next(request)
    for name, value in cors_headers.items():
        response.headers[name] = value
    return response
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.