-
-
Save esev/acdab7e5258d6cbe0d60b46b2b7c43cb to your computer and use it in GitHub Desktop.
X-JWT authentication service for Home Assistant.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Authentication provider. | |
Allow access to users based on a login service and a JWT header.
""" | |
from __future__ import annotations

from collections.abc import Mapping
import datetime
import logging
from typing import Any, cast

from aiohttp.web_request import Request
import jwt
import voluptuous as vol

from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError

from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
from .. import InvalidAuthError
from ..models import Credentials, RefreshToken, User
# Top-level configuration keys understood by this provider.
CONF_JWKS_URL = "jwks_url"
CONF_JWT_HEADER = "jwt_header"
CONF_AUDIENCE = "jwt_audience"
CONF_ISSUER = "jwt_issuer"
CONF_USER_EMAILS = "user_emails"

# Keys inside each entry of the ``user_emails`` list.
CONF_USER = "user"
CONF_EMAIL = "email"

# Defaults containing a literal "[domain]" placeholder; presumably meant to be
# overridden in real configuration.
DEFAULT_ISSUER = "https://[domain]/auth"
DEFAULT_JWKS_URL = DEFAULT_ISSUER + "/.well-known/jwks.json"
# One ``user_emails`` entry: the email to match and the HA user it maps to.
_USER_EMAIL_SCHEMA = vol.Schema(
    {
        vol.Required(CONF_USER): str,
        vol.Required(CONF_EMAIL): str,
    }
)

CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend(
    {
        # JWKS endpoint the signing keys are fetched from.
        vol.Required(CONF_JWKS_URL, default=DEFAULT_JWKS_URL): str,
        # Values passed to jwt.decode as the expected issuer and audience.
        vol.Required(CONF_ISSUER, default=DEFAULT_ISSUER): str,
        vol.Required(CONF_AUDIENCE, default="hass.[domain]"): str,
        # Name of the request header that carries the token.
        vol.Required(CONF_JWT_HEADER, default="X-Jwt"): str,
        vol.Required(CONF_USER_EMAILS): [_USER_EMAIL_SCHEMA],
    },
    extra=vol.PREVENT_EXTRA,
)

_LOGGER = logging.getLogger(__name__)
class InvalidUserError(HomeAssistantError):
    """Signal that a login was attempted as a user this provider rejects."""
@AUTH_PROVIDERS.register("xjwt")
class XjwtAuthProvider(AuthProvider):
    """X-JWT Authentication Provider.

    Allow access to users based on a JWT carried in a request header. The
    token is verified with a key from the configured JWKS endpoint and checked
    against the configured issuer and audience. Its ``sub`` claim is then
    looked up in the configured ``user_emails`` list to find the Home
    Assistant user name.
    """

    DEFAULT_TITLE = "[domain] Login Authentication"

    # Both attributes are filled in by ``async_initialize``. They are created
    # lazily on first use if initialization has not happened yet.
    jwks_client: jwt.PyJWKClient | None = None
    # Configured email (compared with the JWT ``sub`` claim) -> HA user name.
    user_id_map: dict[str, str] | None = None

    async def async_initialize(self) -> None:
        """Create the JWKS client and the email -> user name lookup table."""
        self.jwks_client = await self.hass.async_add_executor_job(
            jwt.PyJWKClient, self.config[CONF_JWKS_URL]
        )
        self.user_id_map = {
            user[CONF_EMAIL]: user[CONF_USER] for user in self.config[CONF_USER_EMAILS]
        }

    async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow:
        """Return a login flow bound to the JWT found on the incoming request.

        Raises:
            InvalidAuthError: if there is no context or request, or if the
                request lacks the configured JWT header.
        """
        # ``context`` is Optional per the signature, so guard it before ``.get``.
        if not context or not (request := cast(Request, context.get("request"))):
            raise InvalidAuthError("request not found in context")
        if not (jwt_ := request.headers.get(self.config[CONF_JWT_HEADER])):
            raise InvalidAuthError("could not get jwt header")
        return XjwtLoginFlow(self, jwt_)

    async def async_user_name_from_jwt(self, jwt_: str) -> str:
        """Verify ``jwt_`` and return the HA user name its subject maps to.

        Raises:
            InvalidAuthError: if the signing key cannot be resolved, if the key
                declares no algorithm, or if signature/issuer/audience
                validation fails.
            InvalidUserError: if the ``sub`` claim is missing, is not a string,
                or is not a configured email.
        """
        if self.jwks_client is None or self.user_id_map is None:
            await self.async_initialize()

        try:
            signing_key = await self.hass.async_add_executor_job(
                self.jwks_client.get_signing_key_from_jwt, jwt_
            )
        except jwt.PyJWTError as exc:
            raise InvalidAuthError("Failed to validate JWT") from exc

        # Accept only the algorithm published with the key, not whatever the
        # token header claims. A key without "alg" is treated as an auth
        # failure rather than leaking a KeyError to the caller.
        # NOTE(review): ``_jwk_data`` is private PyJWT state.
        algorithm = signing_key._jwk_data.get("alg")
        if not algorithm:
            raise InvalidAuthError("Failed to validate JWT: signing key has no alg")

        try:
            # NOTE(review): consider requiring ``exp`` via jwt.decode's
            # ``options={"require": [...]}``. Confirm the IdP always sets it.
            claims = await self.hass.async_add_executor_job(
                lambda: jwt.decode(
                    jwt_,
                    signing_key.key,
                    algorithms=[algorithm],
                    issuer=self.config[CONF_ISSUER],
                    audience=self.config[CONF_AUDIENCE],
                )
            )
        except jwt.PyJWTError as exc:
            raise InvalidAuthError("Failed to validate JWT") from exc

        subject = claims.get("sub")
        if isinstance(subject, str) and (user := self.user_id_map.get(subject)):
            return user
        raise InvalidUserError(f"user not found: {subject}")

    async def async_get_or_create_credentials(
        self, flow_result: Mapping[str, str]
    ) -> Credentials:
        """Return credentials for the existing user named in ``flow_result``.

        Reuse the credential already linked to that user, or create and link a
        new one.

        Raises:
            InvalidUserError: if no active, non-system user has that name.
        """
        user_name = flow_result["user"]
        for user in await self.store.async_get_users():
            # Only real, enabled accounts with the mapped name may log in.
            if user.name != user_name or user.system_generated or not user.is_active:
                continue
            for credential in await self.async_credentials():
                if credential.data.get("user_id") == user.id:
                    return credential
            cred = self.async_create_credentials({"user_id": user.id})
            await self.store.async_link_user(user, cred)
            return cred
        # Login is only allowed as an existing user; never create one here.
        raise InvalidUserError(f"no active user named {user_name!r}")
class XjwtLoginFlow(LoginFlow):
    """Login flow decided entirely by a JWT captured from the request."""

    def __init__(
        self,
        auth_provider: XjwtAuthProvider,
        jwt_: str,
    ) -> None:
        """Initialize the login flow.

        Args:
            auth_provider: Provider used to validate the token.
            jwt_: Raw token read from the configured request header.
        """
        super().__init__(auth_provider)
        self._jwt = jwt_

    async def async_step_init(
        self, user_input: dict[str, str] | None = None
    ) -> FlowResult:
        """Validate the stored JWT, then finish the flow or abort it.

        ``user_input`` is not read: no form is shown, and the outcome depends
        only on the token.
        """
        xjwt = cast(XjwtAuthProvider, self._auth_provider)
        try:
            user = await xjwt.async_user_name_from_jwt(self._jwt)
        except InvalidAuthError:
            _LOGGER.exception("failed")
            return self.async_abort(reason="not_allowed")
        except InvalidUserError as err:
            # No user name is available in this branch. Log the reason carried
            # by the exception (previously an undefined name was referenced).
            _LOGGER.error("User not found: %s", err)
            return self.async_abort(reason="not_allowed")
        return await self.async_finish({"user": user})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment