Last active
August 1, 2025 09:52
-
-
Save s-zeid/864785faf5e92c041a59985491ed043c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| import argparse | |
| import os | |
| import re | |
| import socket | |
| import sys | |
| import threading | |
| from dataclasses import * | |
| import pyinotify # type: ignore # apk add py3-inotify | |
| import yaml # type: ignore # apk add py3-yaml | |
| from aiohttp.http import SERVER_SOFTWARE # type: ignore # apk add py3-aiohttp | |
| from aiohttp import hdrs, web # type: ignore # apk add py3-aiohttp | |
@dataclass
class Root:
    """Base class that owns an aiohttp route table.

    ``routes`` is not an ``__init__`` argument; each instance is given its
    own freshly constructed ``web.RouteTableDef`` by the default factory.
    """

    routes: web.RouteTableDef = field(init=False, default_factory=web.RouteTableDef)
# Where the project lives, and the ``Server`` header value that advertises it.
REPO: str = "https://gist.github.com/864785faf5e92c041a59985491ed043c"
SERVER: str = f"{SERVER_SOFTWARE} webfinger-authelia ({REPO})"

# TCP port used by main() when neither a port nor a UNIX socket is given.
DEFAULT_PORT: int = 12306  # 23: W, 06: F

# Group passed as ``match_group`` when looking up the requested account.
GROUP_NAME: str = "webfinger"

# Maps a group name to the link relation URL advertised for its members.
GROUP_REL_URL_MAP: dict[str, str] = {
    "oidc": "http://openid.net/specs/connect/1.0/issuer",
}
def args(p: argparse.ArgumentParser) -> None:
    """Register the application-specific command-line arguments on *p*.

    Adds two required options (``--issuer-url``/``-i`` and
    ``--user-domain``/``-u``) followed by the positional
    ``users_database_yml`` argument.
    """
    required_options = (
        (("--issuer-url", "-i"), "the issuer URL to use"),
        (("--user-domain", "-u"), "the domain name to use for accounts"),
    )
    for flags, description in required_options:
        p.add_argument(*flags, required=True, help=description)

    p.add_argument("users_database_yml", help="the users_database.yml file to use")
def app(options: argparse.Namespace) -> web.Application:
    """Build the aiohttp application from the parsed command-line *options*.

    Reads ``issuer_url``, ``user_domain`` and ``users_database_yml`` from
    *options*, wires the WebFinger routes into a new ``web.Application`` and
    registers ``response_headers`` on its ``on_response_prepare`` signal.
    """
    users_db = AutheliaUsersDB(options.users_database_yml, options.user_domain)
    root = WebFingerAutheliaRoot(issuer_url=options.issuer_url, users_db=users_db)

    application = web.Application()
    application.add_routes(root.routes)
    # type signature bug: <https://github.com/aio-libs/aiohttp/issues/11264>
    # fixed in aiohttp 3.12.14; Alpine edge still has 3.11.15
    application.on_response_prepare(response_headers)  # type: ignore[arg-type]
    return application
async def response_headers(req: web.Request, resp: web.Response) -> None:
    """Adjust the headers of an outgoing response.

    Replaces the ``Server`` header with ``SERVER`` when it still carries
    aiohttp's stock ``SERVER_SOFTWARE`` value, and sets
    ``Access-Control-Allow-Origin: *`` unless that header is already present.
    *req* is unused.
    """
    headers = resp.headers
    server_is_stock = headers.get(hdrs.SERVER) == SERVER_SOFTWARE
    if server_is_stock:
        headers[hdrs.SERVER] = SERVER
    headers.setdefault(hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, "*")
@dataclass
class WebFingerAutheliaRoot(Root):
    """WebFinger endpoint backed by an Authelia users database.

    On construction a catch-all view (route pattern ``/{_:.*}``) is registered
    on ``self.routes``.  Its GET handler reads the ``resource`` query
    parameter and answers with a JSON document holding ``subject`` and
    ``links``.

    Attributes:
        issuer_url: value used as ``href`` of every returned link.
        users_db: database queried for the requested account.
    """

    _: KW_ONLY
    issuer_url: str
    users_db: "AutheliaUsersDB"

    def __post_init__(self):
        # Inside the view, ``self`` is this root object (captured by closure)
        # and ``view`` is the per-request ``web.View`` instance.
        @self.routes.view("/{_:.*}")
        class WebFingerView(web.View):
            async def get(view) -> web.Response:
                """Handle a WebFinger query.

                Raises:
                    web.HTTPBadRequest: ``resource`` is missing, empty, lacks
                        the ``acct:`` prefix, or contains no ``@``.
                    web.HTTPNotFound: no enabled user in ``GROUP_NAME``
                        matches the resource.
                """
                query = view.request.query
                if not (resource := query.get("resource", "")):
                    raise web.HTTPBadRequest(text="resource is missing or empty")
                if (acct := resource.removeprefix("acct:")) == resource:
                    raise web.HTTPBadRequest(text="resource must start with `acct:`")
                # get_user() raises ValueError for an address without "@";
                # reject it here so the client gets a 400 rather than a 500.
                if "@" not in acct:
                    raise web.HTTPBadRequest(text="resource must have the form `acct:user@domain`")
                if not (user := self.users_db.get_user(acct, match_group=GROUP_NAME)):
                    raise web.HTTPNotFound(text="resource not found")

                # RFC 7033 section 4.3: ``rel`` is an optional filter that may
                # be given several times; without it every link is returned.
                requested_rels = set(query.getall("rel", []))
                links: list[dict[str, str]] = []
                emitted_rels: set[str] = set()
                for group, rel_url in GROUP_REL_URL_MAP.items():
                    if rel_url in emitted_rels or group not in user["groups"]:
                        continue
                    if requested_rels and rel_url not in requested_rels:
                        continue
                    emitted_rels.add(rel_url)
                    links.append({"rel": rel_url, "href": self.issuer_url})

                return web.json_response({"subject": resource, "links": links})
| class AutheliaUsersDB: | |
| path: str | |
| domain: str | |
| db: dict | |
| _read_lock: "threading.RLock" | |
| _watch_thread: threading.Thread | |
| def __init__(self, path: str, domain: str) -> None: | |
| self.path = path | |
| self.domain = domain | |
| self._read_lock = threading.RLock() | |
| self.db = self.read_db() | |
| self._watch_thread = threading.Thread(target=self._watch_thread_target, daemon=True) | |
| self._watch_thread.start() | |
| def get_user(self, acct: str, *, match_group: str | None = None) -> dict | None: | |
| with self._read_lock: | |
| if "@" not in acct: | |
| raise ValueError("acct must have the form of an email address") | |
| username, domain = acct.rsplit("@", 1) | |
| if domain != self.domain: | |
| return None | |
| users = self.db["users"] | |
| for try_username, try_user in users.items(): | |
| if not isinstance(users, dict): | |
| raise ValueError(f"value for user `{username}` in users database must be a dictionary") | |
| if try_username == username: | |
| if "disabled" in try_user: | |
| if not isinstance(try_user["disabled"], bool): | |
| raise ValueError(f"disabled key for user `{username}` must be a boolean") | |
| if try_user["disabled"]: | |
| return None | |
| try_user["groups"] = try_user.get("groups", []) | |
| if match_group: | |
| if try_user["groups"] == []: | |
| continue | |
| if not isinstance(try_user["groups"], list): | |
| raise ValueError(f"groups key for user `{username}` must be a list if present") | |
| if match_group in try_user["groups"]: | |
| return try_user | |
| else: | |
| return try_user | |
| return None | |
| def read_db(self) -> dict: | |
| with self._read_lock: | |
| with open(self.path, "r") as f: | |
| db = yaml.safe_load(f.read()) | |
| if not isinstance(db, dict): | |
| raise ValueError("users database must be a dictionary") | |
| if "users" not in db: | |
| raise ValueError("users database is missing users key") | |
| if not isinstance(db["users"], dict): | |
| raise ValueError("users key in users database must be a dictionary") | |
| return db | |
| def _watch_thread_target(self) -> None: | |
| wm = pyinotify.WatchManager() | |
| wm.add_watch(os.path.dirname(self.path), pyinotify.IN_CREATE | pyinotify.IN_MODIFY) | |
| class EventHandler(pyinotify.ProcessEvent): | |
| def process_default(_, event): | |
| if event.pathname == self.path and os.stat(self.path).st_size: | |
| self.db = self.read_db() | |
| notifier = pyinotify.Notifier(wm, EventHandler()) | |
| notifier.loop() | |
def main(argv: list[str]) -> int:
    """Parse *argv*, start the web server and block until it stops.

    Args:
        argv: full argument vector; ``argv[0]`` is used as the program name.

    Returns:
        ``0`` after the server has shut down, or the exit status chosen by
        argparse when argument parsing ends early (``--help`` or a usage
        error).
    """
    import stat  # local import: only needed for the stale-socket check

    p = argparse.ArgumentParser(prog=os.path.basename(argv[0]))
    p.add_argument("--host", "-H", default=None,
                   help="the host to listen on (default: 127.0.0.1, unless socket is given)")
    # Show the real default instead of a hard-coded number.
    p.add_argument("--port", "-P", type=int, default=None,
                   help=f"the port to listen on (default: {DEFAULT_PORT}, unless socket is given)")
    p.add_argument("--socket", "-S", default=None,
                   help="the UNIX socket to listen on (default: none)")
    args(p)

    # argparse reports --help and usage errors by raising SystemExit; turn
    # that into a return value so the caller decides how to exit.
    try:
        options = p.parse_args(argv[1:])
    except SystemExit as exc:
        return int(exc.code) if exc.code is not None else 127

    host, port = options.host, options.port
    sock = None
    if options.socket:
        # A bare file name has an empty dirname, which makedirs() rejects.
        socket_dir = os.path.dirname(options.socket)
        if socket_dir:
            os.makedirs(socket_dir, 0o755, exist_ok=True)

        # SO_REUSEADDR does not let an AF_UNIX socket rebind an existing
        # path, so remove a socket file left behind by a previous run (the
        # same thing asyncio does for create_unix_server).  Anything that is
        # not a socket is left alone and bind() will fail as before.
        try:
            if stat.S_ISSOCK(os.stat(options.socket).st_mode):
                os.remove(options.socket)
        except FileNotFoundError:
            pass

        sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
        if hasattr(socket, "SO_REUSEADDR"):
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(options.socket)
        sock.listen()
        os.chmod(options.socket, 0o777)
    else:
        host = host or "127.0.0.1"
        port = port or DEFAULT_PORT

    web.run_app(
        app=app(options),
        host=host,
        port=port,
        sock=sock,
        access_log=None,
    )
    return 0
if __name__ == "__main__":
    # Best effort: give the process a short title (program base name plus its
    # arguments).  Skipped silently when setproctitle is not installed.
    try:
        # pip: setproctitle; apt/dnf: python3-setproctitle; apk: py3-setproctitle
        from setproctitle import setproctitle as _setproctitle

        title_parts = [os.path.basename(sys.argv[0]), *sys.argv[1:]]
        # SPT_NOENV is set only for the duration of the call ...
        os.putenv("SPT_NOENV", "1")
        _setproctitle(" ".join(title_parts))
        os.unsetenv("SPT_NOENV")
        # ... and then put back if os.environ has a value for it.
        if "SPT_NOENV" in os.environ:
            os.putenv("SPT_NOENV", os.environ["SPT_NOENV"])
    except (ImportError, KeyError):
        pass

    # Ctrl-C ends the program quietly with a success status.
    try:
        exit_status = main(sys.argv)
    except KeyboardInterrupt:
        exit_status = None
    sys.exit(exit_status)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment