Skip to content

Instantly share code, notes, and snippets.

@linuskohl
Last active December 30, 2022 10:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save linuskohl/024a487c2435ba1287e2d1c9d7406aea to your computer and use it in GitHub Desktop.
Save linuskohl/024a487c2435ba1287e2d1c9d7406aea to your computer and use it in GitHub Desktop.
Helper functions that validate JSON Web Tokens for Flask RESTful APIs, using JWKs fetched from the issuer's authorization server metadata (OAuth 2.0 / OpenID Provider metadata). Used with Okta.
from functools import wraps
from flask import request, abort, g
import json
import jwt
import requests
from typing import Union, List
from ..config import cache
from ..env import JWT_ISSUER, JWT_CLIENTID, JWT_AUDIENCE
# Well-known path appended to the issuer URL to fetch its metadata document
# (the one that advertises ``jwks_uri``). This is the OAuth 2.0 Authorization
# Server Metadata path from RFC 8414.
# NOTE(review): OpenID Connect discovery uses "/.well-known/openid-configuration"
# instead; confirm which document the provider is expected to serve.
DISCOVERY_URL = "/.well-known/oauth-authorization-server"
def login_required(f):
    """Decorator that authenticates a Flask request with a bearer JWT.

    The ``Authorization`` header must have the form ``Bearer <token>``. The
    token's signing key is selected by the ``kid`` in its (unverified) header
    and resolved through ``get_jwk``; the token is then decoded with
    signature, issuer (``JWT_ISSUER``), audience (``JWT_AUDIENCE``) and
    algorithm (RS256 only) checks.

    On success ``g.user`` is set to the ``sub`` claim and ``g.user_id`` to the
    ``uid`` claim before the wrapped view is called. Any failure (missing or
    malformed header, unknown key, invalid token, missing claim) aborts the
    request with HTTP 403.
    """

    @wraps(f)
    def wrap(*args, **kwargs):
        authorization = request.headers.get("authorization", None)
        if not authorization:
            abort(403)

        # Expect exactly "<scheme> <token>" with the Bearer scheme
        # (case-insensitive); split() also tolerates repeated whitespace.
        parts = authorization.split()
        if len(parts) != 2 or parts[0].lower() != "bearer":
            abort(403)
        token_raw = parts[1]

        try:
            key_id = jwt.get_unverified_header(token_raw)["kid"]
            jwk = get_jwk(JWT_ISSUER, JWT_CLIENTID, key_id, cache)
            # Signature verification is PyJWT's default; the legacy
            # ``verify=True`` kwarg is redundant and deprecated in PyJWT 2.x.
            token = jwt.decode(
                token_raw,
                jwk,
                issuer=JWT_ISSUER,
                audience=JWT_AUDIENCE,
                algorithms=["RS256"],
            )
            user = token["sub"]
            user_id = token["uid"]
        except Exception:
            # Deliberately collapse every authentication failure into a 403
            # without exposing the underlying reason to the client.
            abort(403)

        # Publish identity only once all required claims are present.
        g.user = user
        g.user_id = user_id
        return f(*args, **kwargs)

    return wrap
def get_jwk(issuer: str, client_id: str, kid: str, cache=None):
    """Return the public key for the JWK whose key ID is ``kid``.

    The JWK dict is looked up in ``cache`` first, if one is given. On a miss,
    the issuer's key set is fetched via ``fetch_jwks_for``. Every key in it is
    stored in the cache under its own ``kid``, and the matching one is
    selected.

    Args:
        issuer(str): JWT Issuer
        client_id(str): JWT Client ID
        kid(str): Key ID (typically taken from the token header)
        cache(Cache): Optional object with ``get``/``set`` methods used to
            store JWK dicts keyed by key ID

    Returns:
        The key object built by ``jwt.algorithms.RSAAlgorithm.from_jwk``,
        suitable as the key argument to ``jwt.decode``.

    Raises:
        LookupError: If the issuer publishes no key with this ``kid``.
    """
    key = cache.get(kid) if cache is not None else None

    if key is None:
        # NOTE(review): an unknown kid triggers a network fetch on every call;
        # consider negative caching or rate limiting.
        for candidate in fetch_jwks_for(issuer, client_id) or []:
            candidate_kid = candidate.get("kid")
            if candidate_kid is None:
                continue  # unaddressable key; cannot be selected or cached
            if cache is not None:
                # Persist each key under its own kid so later requests for
                # any published key are cache hits with the right material.
                cache.set(candidate_kid, candidate)
            if candidate_kid == kid:
                key = candidate

    if key is None:
        raise LookupError(f"No JWK found for kid {kid!r}")
    return jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key))
def fetch_jwks_for(issuer: str, client_id: str, timeout: float = 10.0) -> Union[None, List]:
    """Fetch the issuer's JSON Web Key Set.

    The ``jwks_uri`` is read from the metadata returned by
    ``fetch_metadata_for``, and the document at that URI is downloaded.

    Args:
        issuer(str): JWT Issuer
        client_id(str): JWT Client ID
        timeout(float): Seconds to wait for the JWKS request before giving up

    Returns:
        List: The ``keys`` array of the JWK set, or None if it has none.

    Raises:
        ValueError: If the metadata does not advertise a ``jwks_uri``.
        requests.HTTPError: If the JWKS endpoint returns an error status.
    """
    oidp_metadata = fetch_metadata_for(issuer, client_id)
    jwks_uri = oidp_metadata.get("jwks_uri")
    if not jwks_uri:
        raise ValueError(f"Metadata for issuer {issuer!r} has no 'jwks_uri'")

    # Bound the request so a slow key endpoint cannot hang the caller.
    response = requests.get(jwks_uri, timeout=timeout)
    # Don't try to parse an error page as a key set.
    response.raise_for_status()
    return response.json().get("keys")
def fetch_metadata_for(issuer: str, client_id: str, timeout: float = 10.0) -> dict:
    """Fetch the issuer's authorization server metadata document.

    The document is requested from ``issuer + DISCOVERY_URL``, with the
    client ID passed as the ``client_id`` query parameter.

    Args:
        issuer(str): JWT Issuer (a trailing slash is tolerated)
        client_id(str): JWT Client ID
        timeout(float): Seconds to wait for the request before giving up

    Returns:
        dict: The decoded JSON metadata

    Raises:
        requests.HTTPError: If the endpoint returns an error status.
    """
    # Avoid "...//.well-known/..." when the issuer is configured with a
    # trailing slash.
    url = issuer.rstrip("/") + DISCOVERY_URL
    response = requests.get(url, params={"client_id": client_id}, timeout=timeout)
    response.raise_for_status()
    return response.json()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment