Skip to content

Instantly share code, notes, and snippets.

@esev
Created August 18, 2023 21:54
Show Gist options
  • Save esev/acdab7e5258d6cbe0d60b46b2b7c43cb to your computer and use it in GitHub Desktop.
Save esev/acdab7e5258d6cbe0d60b46b2b7c43cb to your computer and use it in GitHub Desktop.
X-JWT authentication service for Home Assistant.
"""Authentication provider.
Allow access to users based on a login service and a JWT header.
"""
from __future__ import annotations

from collections.abc import Mapping
import datetime
import logging
from typing import Any, cast

from aiohttp.web_request import Request
import jwt
import voluptuous as vol

from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError

from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
from .. import InvalidAuthError
from ..models import Credentials, RefreshToken, User
# Keys accepted in this provider's configuration block.
CONF_JWKS_URL = "jwks_url"
CONF_JWT_HEADER = "jwt_header"
CONF_AUDIENCE = "jwt_audience"
CONF_ISSUER = "jwt_issuer"
CONF_USER_EMAILS = "user_emails"

# Keys of each entry inside the ``user_emails`` list.
CONF_USER = "user"
CONF_EMAIL = "email"

# Defaults used when the issuer / JWKS URL are not configured explicitly.
# The JWKS URL default is derived from the issuer default.
DEFAULT_ISSUER = "https://[domain]/auth"
DEFAULT_JWKS_URL = f"{DEFAULT_ISSUER}/.well-known/jwks.json"

# Shape of a single ``user_emails`` entry: both keys are mandatory strings.
_USER_EMAIL_ENTRY_SCHEMA = vol.Schema(
    {
        vol.Required(CONF_USER): str,
        vol.Required(CONF_EMAIL): str,
    }
)

# Full provider schema: the base auth-provider schema plus the JWT settings.
# Unknown keys are rejected (PREVENT_EXTRA).
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend(
    {
        vol.Required(CONF_JWKS_URL, default=DEFAULT_JWKS_URL): str,
        vol.Required(CONF_ISSUER, default=DEFAULT_ISSUER): str,
        vol.Required(CONF_AUDIENCE, default="hass.[domain]"): str,
        vol.Required(CONF_JWT_HEADER, default="X-Jwt"): str,
        vol.Required(CONF_USER_EMAILS): [_USER_EMAIL_ENTRY_SCHEMA],
    },
    extra=vol.PREVENT_EXTRA,
)

_LOGGER = logging.getLogger(__name__)
class InvalidUserError(HomeAssistantError):
    """Raised when a login is attempted as a user that is not valid.

    In this module it signals that the authenticated identity could not be
    matched to a user that is allowed to log in.
    """
@AUTH_PROVIDERS.register("xjwt")
class XjwtAuthProvider(AuthProvider):
    """X-JWT Authentication Provider.

    Allow access to users based on a JWT passed in a request header.

    The token's signature, issuer and audience are verified using the keys
    served at the configured JWKS URL. The token's ``sub`` claim is then
    looked up among the configured ``user_emails`` entries to find the name
    of the existing user to log in as.
    """

    DEFAULT_TITLE = "[domain] Login Authentication"

    # Both attributes are filled in lazily by async_initialize().
    jwks_client: jwt.PyJWKClient | None = None
    # Maps the configured email of each entry to its configured user name.
    user_id_map: dict[str, str] | None = None

    async def async_initialize(self) -> None:
        """Initialize the auth provider.

        Create the JWKS client for the configured URL and build the
        email -> user name lookup table from the ``user_emails`` config.
        """
        self.jwks_client = await self.hass.async_add_executor_job(
            jwt.PyJWKClient, self.config[CONF_JWKS_URL]
        )
        self.user_id_map = {
            entry[CONF_EMAIL]: entry[CONF_USER]
            for entry in self.config[CONF_USER_EMAILS]
        }

    async def async_login_flow(self, context: dict[str, Any] | None) -> LoginFlow:
        """Return a flow to login.

        Raises:
            InvalidAuthError: if ``context`` carries no request, or the
                request does not have the configured JWT header.
        """
        # The signature allows context to be None; treat that like an empty
        # context instead of failing with AttributeError on ``.get``.
        request = cast("Request | None", (context or {}).get("request"))
        if request is None:
            raise InvalidAuthError("request not found in context")

        if not (jwt_ := request.headers.get(self.config[CONF_JWT_HEADER])):
            raise InvalidAuthError("could not get jwt header")

        return XjwtLoginFlow(self, jwt_)

    async def async_user_name_from_jwt(self, jwt_: str) -> str:
        """Validate the JWT and return the user name configured for it.

        Raises:
            InvalidAuthError: if the token cannot be validated, the signing
                key does not declare an algorithm, or the token has no
                usable ``sub`` claim.
            InvalidUserError: if the token's subject is not configured.
        """
        if self.jwks_client is None or self.user_id_map is None:
            await self.async_initialize()

        def _decode() -> dict[str, Any]:
            """Fetch the signing key and decode the token (blocking)."""
            signing_key = self.jwks_client.get_signing_key_from_jwt(jwt_)
            # The algorithm is taken from the published key, not from the
            # token header. "alg" is optional in a JWK, so guard the lookup.
            algorithm = signing_key._jwk_data.get("alg")
            if not algorithm:
                raise InvalidAuthError("signing key does not declare an algorithm")
            return jwt.decode(
                jwt_,
                signing_key.key,
                algorithms=[algorithm],
                issuer=self.config[CONF_ISSUER],
                audience=self.config[CONF_AUDIENCE],
            )

        try:
            # Key lookup and decoding are both blocking, so run them
            # together in a single executor job.
            claims = await self.hass.async_add_executor_job(_decode)
        except jwt.PyJWTError as exc:
            raise InvalidAuthError("Failed to validate JWT") from exc

        subject = claims.get("sub")
        if not isinstance(subject, str) or not subject:
            raise InvalidAuthError("JWT does not contain a subject claim")

        if user := self.user_id_map.get(subject):
            return user

        raise InvalidUserError(f"user not found: {subject}")

    async def async_get_or_create_credentials(
        self, flow_result: Mapping[str, str]
    ) -> Credentials:
        """Get credentials based on the flow result.

        Look for an active, non-system user whose name equals
        ``flow_result["user"]``. Return that user's existing credential for
        this provider if there is one; otherwise create a credential and
        link it to the user.

        Raises:
            InvalidUserError: if no such user exists.
        """
        user_name = flow_result["user"]

        for user in await self.store.async_get_users():
            if (
                user.name != user_name
                or user.system_generated
                or not user.is_active
            ):
                continue

            for credential in await self.async_credentials():
                if credential.data["user_id"] == user.id:
                    return credential

            cred = self.async_create_credentials({"user_id": user.id})
            await self.store.async_link_user(user, cred)
            return cred

        # We only allow login as an existing user.
        raise InvalidUserError(f"no active user named {user_name!r}")
class XjwtLoginFlow(LoginFlow):
    """Handler for the login flow.

    Holds the JWT taken from the request and resolves it to a user name in
    the single ``init`` step; there is no form for the user to fill in.
    """

    def __init__(
        self,
        auth_provider: XjwtAuthProvider,
        jwt_: str,
    ) -> None:
        """Initialize the login flow.

        Args:
            auth_provider: the provider that created this flow.
            jwt_: the raw token read from the request header.
        """
        super().__init__(auth_provider)
        self._jwt = jwt_

    async def async_step_init(
        self, user_input: dict[str, str] | None = None
    ) -> FlowResult:
        """Handle the step of the form.

        Validate the stored JWT through the provider. On success finish the
        flow with ``{"user": <name>}``; on any authentication or user lookup
        failure abort with reason ``not_allowed``. ``user_input`` is unused.
        """
        xjwt = cast(XjwtAuthProvider, self._auth_provider)

        try:
            user = await xjwt.async_user_name_from_jwt(self._jwt)
        except InvalidAuthError:
            _LOGGER.exception("failed")
            return self.async_abort(reason="not_allowed")
        except InvalidUserError as exc:
            # Log the caught exception: no user identifier is available in
            # this scope, only the error raised by the provider.
            _LOGGER.error("JWT login rejected: %s", exc)
            return self.async_abort(reason="not_allowed")

        return await self.async_finish({"user": user})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment