Created
June 27, 2024 22:41
-
-
Save zonca/881707a069495c722edc482219da487f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Environment for the FastAPI + GitHub SSO demo; fill in and `source` before running.
# GitHub OAuth app credentials, passed to GithubSSO.
export CLIENT_ID=''
export CLIENT_SECRET=''
# Key used to sign and verify the session JWT (HS256).
export SECRET_KEY=''
# Personal access token (classic) with the `read:org` scope, used for org-membership checks.
export GITHUB_PAT=''
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import datetime # to calculate expiration of the JWT | |
from fastapi import FastAPI, Depends, HTTPException, Security, Request | |
from fastapi.responses import RedirectResponse | |
from fastapi.security import APIKeyCookie # this is the part that puts the lock icon to the docs | |
from fastapi_sso.sso.github import GithubSSO | |
from fastapi_sso.sso.base import OpenID | |
import httpx | |
from github import Github | |
from jose import jwt # pip install python-jose[cryptography] | |
SECRET_KEY = os.environ["SECRET_KEY"]  # HS256 key used to sign/verify the session JWT
CLIENT_ID = os.environ["CLIENT_ID"]  # GitHub OAuth app client id
CLIENT_SECRET = os.environ["CLIENT_SECRET"]  # GitHub OAuth app client secret

# Callback URL registered with the GitHub OAuth app. Overridable through the
# environment so the app can run under a real hostname; the default targets the
# local development server started in the __main__ guard.
REDIRECT_URI = os.environ.get("REDIRECT_URI", "http://localhost:5000/auth/callback")

sso = GithubSSO(
    client_id=CLIENT_ID,
    client_secret=CLIENT_SECRET,
    redirect_uri=REDIRECT_URI,
    # Plain-HTTP callbacks are only allowed when the redirect URI itself is
    # http:// (local development); an https:// deployment keeps the safe default.
    allow_insecure_http=REDIRECT_URI.startswith("http://"),
)

# This is a personal access token (classic) with the `read:org` scope
PAT = os.environ["GITHUB_PAT"]
# GitHub organization whose membership is reported for non-public datasets.
ORG = os.environ.get("GITHUB_ORG", "simonsobs")

app = FastAPI()
async def optional_get_logged_user(dataset_id: int, request: Request):
    """Resolve the logged-in user only when the dataset needs authentication.

    Dataset ``0`` is public: no cookie is inspected and ``None`` is returned.
    Every other dataset delegates to ``get_logged_user`` and propagates
    whatever it returns or raises.
    """
    needs_login = dataset_id != 0
    if needs_login:
        return await get_logged_user(request)
    return None
async def get_logged_user(request: Request) -> OpenID:
    """Return the OpenID of the user whose JWT is stored in the ``token`` cookie.

    The cookie is read through ``APIKeyCookie``, decoded with ``SECRET_KEY``
    (HS256 only), and the ``pld`` claim is turned back into an ``OpenID``.

    Raises:
        HTTPException: whatever ``APIKeyCookie`` raises when it cannot read the
            cookie (presumably a missing cookie — TODO confirm against the FastAPI
            version in use), or 401 when the token is invalid or expired, or its
            payload cannot be turned into an ``OpenID``.
    """
    # Local import keeps this block self-contained.
    from jose import JWTError

    cookie = await APIKeyCookie(name="token")(request)
    try:
        claims = jwt.decode(cookie, key=SECRET_KEY, algorithms=["HS256"])
        return OpenID(**claims["pld"])
    except (JWTError, KeyError, TypeError, ValueError) as error:
        # A bad signature, an expired token, or a malformed payload (missing
        # "pld", non-mapping payload, or a validation error, which is a
        # ValueError) is a client problem: answer 401 instead of letting it
        # surface as a 500.
        raise HTTPException(
            status_code=401, detail="Invalid authentication credentials"
        ) from error
def check_user_in_org(github_username, org_name=None):
    """Return whether ``github_username`` is a member of a GitHub organization.

    Args:
        github_username: GitHub login to look up. Callers pass
            ``OpenID.display_name``; presumably GithubSSO fills it with the login
            — TODO confirm.
        org_name: organization to check; defaults to the module-level ``ORG``.

    Returns:
        False for an empty/None username or a login GitHub reports as unknown;
        otherwise the result of ``Organization.has_in_members``.

    Raises:
        github.GithubException: for other API failures, including an unknown
            organization (a configuration error that should not be masked).

    This makes synchronous HTTP requests authenticated with ``PAT``, so async
    callers should run it in a worker thread.
    """
    from github import UnknownObjectException

    if not github_username:
        # PyGithub rejects None, and "" is not a real login; nobody to check.
        return False
    client = Github(PAT)
    org = client.get_organization(org_name or ORG)
    try:
        user = client.get_user(github_username)
    except UnknownObjectException:
        # A login GitHub does not know cannot be a member.
        return False
    return org.has_in_members(user)
@app.get("/protected/{dataset_id}")
async def protected_endpoint(dataset_id: int, request: Request, user: OpenID = Depends(optional_get_logged_user)):
    """Greet the caller, reporting organization membership for non-public datasets.

    Dataset ``0`` is public: ``optional_get_logged_user`` yields ``None`` and no
    authentication is required. For any other dataset the dependency resolves
    the user from the session cookie (failing the request if it cannot), and the
    reply states whether that user belongs to ``ORG``.
    """
    from fastapi.concurrency import run_in_threadpool

    if dataset_id == 0:
        return {"message": "This should be public and not require authentication."}

    # check_user_in_org uses blocking PyGithub calls; run it in a worker thread
    # so a slow GitHub round-trip does not stall the event loop.
    is_member = await run_in_threadpool(check_user_in_org, user.display_name)
    not_member = "" if is_member else "not "
    # OpenID.email may be None (e.g. a hidden GitHub email — TODO confirm);
    # fall back to the display name rather than greeting "None".
    greeting_name = user.email or user.display_name
    return {
        "message": f"You are very welcome, {greeting_name}! You are {not_member}a member of the {ORG} organization."
    }
@app.get("/auth/login")
async def login():
    """Start the OAuth flow by sending the browser to GitHub's login page."""
    with sso:
        github_redirect = await sso.get_login_redirect()
    return github_redirect
@app.get("/auth/logout")
async def logout():
    """Forget the user's session and redirect to the public dataset page."""
    # Only /protected/{dataset_id} is routed, so a bare "/protected" would 404;
    # dataset 0 is the public page that needs no login.
    response = RedirectResponse(url="/protected/0")
    # Dropping the cookie removes the JWT that get_logged_user reads.
    response.delete_cookie(key="token")
    return response
@app.get("/auth/callback")
async def login_callback(request: Request):
    """Finish the GitHub OAuth flow, store a session JWT, and redirect.

    The verified OpenID is embedded (claim ``pld``) in an HS256 JWT that expires
    one day from now. The JWT is saved in the ``token`` cookie, which
    ``get_logged_user`` reads back.

    Raises:
        HTTPException: 401 when ``verify_and_process`` yields no user.
    """
    with sso:
        openid = await sso.verify_and_process(request)
    if not openid:
        raise HTTPException(status_code=401, detail="Authentication failed")

    # Create a JWT with the user's OpenID
    expiration = datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta(days=1)
    # pydantic v2 deprecates .dict() in favour of .model_dump(); support both.
    payload = openid.model_dump() if hasattr(openid, "model_dump") else openid.dict()
    token = jwt.encode(
        {"pld": payload, "exp": expiration, "sub": openid.id},
        key=SECRET_KEY,
        algorithm="HS256",
    )

    response = RedirectResponse(url="/protected/1")
    # This cookie will make sure /protected knows the user.
    response.set_cookie(
        key="token",
        value=token,
        expires=expiration,
        httponly=True,  # the cookie is a bearer credential: hide it from page JavaScript
        secure=request.url.scheme == "https",  # only force HTTPS-only when served over HTTPS
        samesite="lax",  # still sent on top-level navigations such as this redirect
    )
    return response
if __name__ == "__main__":
    # Development server. Port 5000 must match the port in the GitHub OAuth callback URL.
    from uvicorn import run

    run(app, host="127.0.0.1", port=5000)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment