@scott2b
Last active November 30, 2022 16:32
A simple authentication proxy for mlflow implemented in Starlette
"""
****
CAUTION: Parsing the request body as JSON (see json_data below) is the wrong approach and
will not work for all requests! Instead of parsing out the JSON data, the request body
should simply be passed through to mlflow unchanged. I will get a fix in soon.
****
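To illustrate why re-encoding the body is fragile, here is a small standard-library sketch that compares raw pass-through with the parse-and-reserialize approach this file currently uses. The helpers `forward_raw` and `forward_via_json` and the sample bodies are hypothetical, and `json.dumps` only approximates how httpx serializes its `json=` argument.

```python
import json


def forward_raw(body: bytes) -> bytes:
    # Pass-through: upstream receives exactly the bytes the client sent.
    return body


def forward_via_json(body: bytes) -> bytes:
    # Parse-and-reserialize; json.dumps stands in for httpx's json= encoding.
    return json.dumps(json.loads(body)).encode("utf-8")


samples = [
    b'{"name": "exp-1",   "tags": []}',  # valid JSON with extra whitespace
    b"name=exp-1&tags=a",  # form-encoded body, not JSON
    b"",  # request with an empty body
]

results = {}
for body in samples:
    try:
        # True means the upstream server would see different bytes.
        results[body] = forward_via_json(body) != forward_raw(body)
    except json.JSONDecodeError:
        # None means the JSON approach cannot forward this body at all.
        results[body] = None
    print(body, "->", results[body])
```

In a Starlette handler, the pass-through variant would roughly correspond to reading `await request.body()` and passing it upstream with httpx's `content=` argument instead of `json=`.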
A simple proxy to put an authentication layer in front of mlflow.
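Requires [Starlette](https://www.starlette.io/), plus httpx (for forwarding requests),
uvicorn (for serving), and itsdangerous (needed by Starlette's SessionMiddleware).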
mlflow does not provide authentication, and securing mlflow deployments is a common need.
Some things here are mlflow specific, but this gist could easily be adapted to other
similar use cases.
I wrote most of this before I ran across the mlflow-easyauth project. That may be a better
choice, although I have not tried it. Glancing at that code, it is not clear to me how
mlflow-easyauth handles the UI, which, as far as I know, has no mechanism for passing
authentication in requests. The approach here is to add login and logout endpoints to
the proxy, along with authentication middleware. A session is created for the user after
the Basic authentication credentials are checked, and it is destroyed by visiting /logout.
Unfortunately, I do not see a way to add login and logout to the UI navigation.
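The login flow described above comes down to decoding a Basic `Authorization` header and checking it against known credentials. Here is a standard-library sketch of that check. `check_basic_auth` and the `users` mapping are hypothetical names. Unlike the dev backend in this file, the sketch compares passwords with `hmac.compare_digest`, so the comparison does not stop early at the first differing byte.

```python
import base64
import binascii
import hmac
from typing import Optional


def check_basic_auth(header: str, users: dict) -> Optional[str]:
    # Return the username if the header carries valid credentials, else None.
    parts = header.split()
    if len(parts) != 2 or parts[0].lower() != "basic":
        return None  # missing, malformed, or not Basic auth
    try:
        decoded = base64.b64decode(parts[1], validate=True).decode("utf-8")
    except (binascii.Error, UnicodeDecodeError):
        return None
    # Split on the first colon only: the password may contain colons.
    username, sep, password = decoded.partition(":")
    expected = users.get(username)
    if not sep or expected is None:
        return None
    # compare_digest avoids short-circuiting on content, to resist timing attacks.
    if hmac.compare_digest(password.encode("utf-8"), expected.encode("utf-8")):
        return username
    return None


users = {"user1": "password1"}  # hypothetical credential store
header = "Basic " + base64.b64encode(b"user1:password1").decode("ascii")
print(check_basic_auth(header, users))  # user1
```

Splitting on the first colon follows RFC 7617, which forbids colons in the user-id but allows them in the password.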
If you are specifically using OAuth, you might consider something like
[OAuth2 Proxy](https://oauth2-proxy.github.io/oauth2-proxy/) for mlflow, although I
have not tried it myself. As far as I can tell, the mlflow clients only pass credentials
as Basic auth headers or as a static bearer token (MLFLOW_TRACKING_TOKEN), so for a full
OAuth flow you would probably have to implement your own client.
The general assumption here is that the mlflow server at SITE_ROOT is only reachable
internally, so all outside traffic has to go through this proxy. How you protect that
endpoint depends on your setup, but it will typically be a network-level control such as
firewall rules or IP allow-lists.
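Because clients only ever reach mlflow through the proxy, every forwarded request has to be re-targeted at SITE_ROOT. The following is a minimal standard-library sketch of that rewrite. It assumes SITE_ROOT has no path prefix and that only the path and query string of the incoming URL should be kept. `upstream_url` is a hypothetical helper, not a function in this file.

```python
from urllib.parse import urlsplit, urlunsplit

SITE_ROOT = "http://localhost:5000"  # internal mlflow server (same default as below)


def upstream_url(incoming: str, site_root: str = SITE_ROOT) -> str:
    # Scheme and host come from site_root; path and query come from the request,
    # so the result does not depend on which hostname the client used.
    # Assumes site_root has no path prefix of its own.
    parts = urlsplit(incoming)
    site = urlsplit(site_root)
    return urlunsplit((site.scheme, site.netloc, parts.path, parts.query, ""))


print(upstream_url("http://proxy.example.com:5002/api/2.0/mlflow/experiments/search?max_results=10"))
# http://localhost:5000/api/2.0/mlflow/experiments/search?max_results=10
```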
## Env variables
Set the following in your client's environment for programmatic API access. The mlflow
client sends them to the proxy as HTTP Basic credentials:
MLFLOW_TRACKING_USERNAME=username
MLFLOW_TRACKING_PASSWORD=password
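The mlflow client sends these two variables to the tracking server as HTTP Basic credentials. To test the proxy with other HTTP tools, the same `Authorization` header can be built with the standard library. `basic_auth_header` is a hypothetical helper, not part of mlflow's API.

```python
import base64
import os


def basic_auth_header(environ=os.environ) -> dict:
    # Build the header from the same variables the mlflow client reads.
    username = environ["MLFLOW_TRACKING_USERNAME"]
    password = environ["MLFLOW_TRACKING_PASSWORD"]
    token = base64.b64encode(f"{username}:{password}".encode("utf-8")).decode("ascii")
    return {"Authorization": f"Basic {token}"}


# Example with an explicit mapping instead of the real environment.
print(basic_auth_header({"MLFLOW_TRACKING_USERNAME": "user1",
                         "MLFLOW_TRACKING_PASSWORD": "password1"}))
```

The dev backend in this file accepts a valid Basic header on any request. Sending this header to PROXY_ROOT should therefore reach both the API routes and the UI without a visit to /login.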
Save this file as mlflowproxy.py and run it with:
```
uvicorn mlflowproxy:app --host 0.0.0.0 --port 5002
```
## Related resources
- [mlflow](https://mlflow.org/)
- [Starlette](https://www.starlette.io/)
- [mlflow-easyauth](https://github.com/soundsensing/mlflow-easyauth). Basic auth for mlflow
"""
import json
from base64 import b64decode

import httpx
from starlette.applications import Starlette
from starlette.authentication import (
    AuthCredentials,
    AuthenticationBackend,
    SimpleUser,
    requires,
)
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import HTMLResponse, RedirectResponse, Response
from starlette.routing import Route
DEBUG = True
PROXY_ROOT = "http://localhost:5002" # Set these for your proxy and site.
SITE_ROOT = "http://localhost:5000" # See notes above about protecting your site.
NO_DATA_METHODS = ["GET", "HEAD"]
DATA_METHODS = [
    "POST",
    "PUT",
    "PATCH",
    "DELETE",
]  # mlflow sends data for delete requests
ALL_METHODS = NO_DATA_METHODS + DATA_METHODS
REMOVE_REQUEST_HEADERS = ["content-length"]
SECRET_KEY = "supersecretchangeme" # used by session middleware
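def decode_password(header):
    """Return (username, password) from a Basic Authorization header value.

    Raises ValueError if the header is not Basic auth or is malformed.
    """
    scheme, cred = header.split()  # ValueError unless there are exactly two parts
    if scheme.lower() != "basic":
        raise ValueError("Not a Basic authorization header")
    # b64decode raises binascii.Error and .decode() raises UnicodeDecodeError;
    # both are ValueError subclasses. partition() splits on the first colon only.
    username, _, password = b64decode(cred).decode().partition(":")
    return username, password


def clean_request_headers(headers):
    return {k: v for k, v in headers.items() if k not in REMOVE_REQUEST_HEADERS}


async def json_data(request):
    """Return the request body parsed as JSON, or None if it is empty or not JSON.

    Non-JSON bodies are dropped rather than forwarded; see the CAUTION note above.
    """
    try:
        return await request.json()
    except (AttributeError, json.decoder.JSONDecodeError):
        return None


# Probably don't do this in production
USERS = {
    "user1": "password1",
    "user2": "password2",
}


class DevAuthBackend(AuthenticationBackend):
    """For dev and testing only: checks Basic credentials against the USERS dict."""

    async def authenticate(self, conn):
        """Note that Starlette's sessions send session data to the client, so the
        username is readable but not modifiable. You may want to use an alternative
        session middleware that stores session data on the server, or otherwise
        obfuscate the user's identity in the session data.
        """
        if "username" in conn.session and conn.session["username"] in USERS:
            return AuthCredentials(["api_auth", "app_auth"]), SimpleUser(
                conn.session["username"]
            )
        if "authorization" not in conn.headers:
            return None
        try:
            username, password = decode_password(conn.headers["authorization"])
        except ValueError:
            # Malformed or non-Basic header: treat the request as unauthenticated.
            return None
        if password == USERS.get(username):
            conn.session["username"] = username
            return AuthCredentials(["api_auth", "app_auth"]), SimpleUser(username)
        return None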
class AuthBackend(AuthenticationBackend):
    async def authenticate(self, request):
        """Implement this with your authentication requirements. Presumably, you would
        do something a bit more sophisticated here than a hard-coded dictionary.
        See: https://www.starlette.io/authentication/
        """
        raise NotImplementedError


# AUTH_BACKEND = AuthBackend  # Needs to be implemented
AUTH_BACKEND = DevAuthBackend  # For dev purposes only!
async def dispatch(url, method, **kwargs):
    async with httpx.AsyncClient() as client:
        return await client.request(url=url, method=method, **kwargs)
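async def main(request):
    if not request.user.is_authenticated:
        return RedirectResponse(request.url_for("login"))
    if "app_auth" not in request.auth.scopes:
        return RedirectResponse(request.url_for("logout"))
    # Build the upstream URL from the path and query string. Slicing
    # str(request.url) by len(PROXY_ROOT) only works when clients reach the proxy
    # through exactly PROXY_ROOT, not through some other hostname or IP.
    url = SITE_ROOT + request.url.path
    if request.url.query:
        url += "?" + request.url.query
    req_headers = clean_request_headers(request.headers)
    if request.method in DATA_METHODS:
        resp = await dispatch(
            url, request.method, headers=req_headers, json=await json_data(request)
        )
    else:  # NO_DATA_METHODS; the routes only allow ALL_METHODS
        resp = await dispatch(url, request.method, headers=req_headers)
    # httpx has already decoded the response body (e.g. gzip), so drop headers that
    # describe the upstream encoding and length; Starlette sets content-length itself.
    excluded = {"content-encoding", "content-length", "transfer-encoding", "connection"}
    resp_headers = {k: v for k, v in resp.headers.items() if k.lower() not in excluded}
    return Response(
        resp.content,
        headers=resp_headers,
        media_type=resp.headers.get("content-type"),
        status_code=resp.status_code,
    )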
@requires("api_auth")
async def api(request):
    return await main(request)


async def logout(request):
    request.session.clear()
    return RedirectResponse(request.url_for("login"))
async def login(request):
    if request.method == "POST":
        # AuthenticationMiddleware has already checked the Basic credentials sent by
        # the login form and, if they are valid, stored the username in the session.
        if request.user.is_authenticated:
            return Response(status_code=200)
        return Response(status_code=403)
    if request.user.is_authenticated:
        return RedirectResponse(request.url_for("home"))
    return HTMLResponse(
        content="""<html>
<head>
</head>
<body>
  <form id="loginForm">
    <input type="text" name="username" placeholder="username">
    <input type="password" name="password" placeholder="password">
    <input type="submit" value="Login">
  </form>
  <div id="message"></div>
  <script>
    const loginForm = document.getElementById("loginForm");
    const message = document.getElementById("message");
    loginForm.onsubmit = async (e) => {
      e.preventDefault();
      const data = new FormData(loginForm);
      // The credentials travel as a Basic Authorization header, which the
      // authentication backend checks before this endpoint runs.
      const response = await fetch('/login', {
        method: 'POST',
        headers: {
          'Authorization': 'Basic ' + btoa(data.get("username") + ':' + data.get("password"))
        }
      });
      if (response.status == 200) {
        window.location.replace("/");
      } else if (response.status == 403) {
        message.innerText = "Incorrect username or password";
      } else {
        message.innerText = "Unknown error. Unable to log in.";
      }
    };
  </script>
</body>
</html>
""",
        status_code=200,
    )
"""
Routing here is specific to mlflow and is designed to allow programmatic access to the
API with credentials, given a proper implementation of the auth backend above. For use
cases other than mlflow that have no API, remove the ajax-api and api routes and all
requests will go to main. mlflow appears to have multiple internal API roots (/api and
/ajax-api), which are both handled here.
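Starlette checks routes in list order and dispatches to the first one whose path and method match. The specific routes therefore have to come before the catch-all. The standard-library sketch below mimics that first-match dispatch with simple prefix checks. `ROUTE_TABLE` and `pick_handler` are hypothetical, and the sketch ignores Starlette's real path-parameter matching.

```python
# Ordered rules mirroring the routes list: (prefix, handler, exact_match).
ROUTE_TABLE = [
    ("/login", "login", True),
    ("/logout", "logout", True),
    ("/ajax-api/", "api", False),
    ("/api/", "api", False),
    ("/", "main", False),  # catch-all, like Route("/{path:path}", main)
]


def pick_handler(path: str) -> str:
    # Return the handler of the first matching rule, as a first-match router would.
    for prefix, handler, exact in ROUTE_TABLE:
        matched = (path == prefix) if exact else path.startswith(prefix)
        if matched:
            return handler
    raise LookupError(f"no route for {path}")


for p in ["/login", "/api/2.0/mlflow/runs/get", "/ajax-api/2.0/mlflow/runs/search", "/"]:
    print(p, "->", pick_handler(p))
```

Requests under /api/ and /ajax-api/ go through `api`. Its `@requires("api_auth")` decorator answers unauthenticated requests with 403. Every other path goes to `main`, which redirects to the login page instead.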
"""
routes = [
    Route("/login", login, methods=["GET", "POST"], name="login"),
    Route("/logout", logout, methods=["GET"], name="logout"),
    Route("/ajax-api/{path:path}", api, methods=ALL_METHODS),
    Route("/api/{path:path}", api, methods=ALL_METHODS),
    # Starlette uses the first matching route, so "/" goes before the catch-all;
    # otherwise "/{path:path}" would also match "/" and "home" would only serve url_for.
    Route("/", main, methods=ALL_METHODS, name="home"),
    Route("/{path:path}", main, methods=ALL_METHODS),
]
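app = Starlette(debug=DEBUG, routes=routes)
# Middleware added later wraps middleware added earlier, so SessionMiddleware (added
# below) runs first and conn.session is available to the authentication backend.
app.add_middleware(AuthenticationMiddleware, backend=AUTH_BACKEND())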
# See the Starlette SessionMiddleware docs: https://www.starlette.io/middleware/#sessionmiddleware
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY, https_only=False)