Skip to content

Instantly share code, notes, and snippets.

@s-zeid
Last active August 1, 2025 09:52
Show Gist options
  • Select an option

  • Save s-zeid/864785faf5e92c041a59985491ed043c to your computer and use it in GitHub Desktop.

Select an option

Save s-zeid/864785faf5e92c041a59985491ed043c to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import argparse
import os
import re
import socket
import sys
import threading
from dataclasses import *
import pyinotify # type: ignore # apk add py3-inotify
import yaml # type: ignore # apk add py3-yaml
from aiohttp.http import SERVER_SOFTWARE # type: ignore # apk add py3-aiohttp
from aiohttp import hdrs, web # type: ignore # apk add py3-aiohttp
@dataclass
class Root:
    """Base class for objects that carry their own aiohttp route table.

    `routes` is created fresh for every instance and is not an __init__
    parameter; views are registered on `self.routes` and the table is later
    handed to `web.Application.add_routes`.
    """

    routes: web.RouteTableDef = field(init=False, default_factory=web.RouteTableDef)
# Source of this program, advertised in the Server response header.
REPO = "https://gist.github.com/864785faf5e92c041a59985491ed043c"
# Server header value: aiohttp's SERVER_SOFTWARE string plus this program's name and source.
SERVER = f"{SERVER_SOFTWARE} webfinger-authelia ({REPO})"

# TCP port used when neither --port nor --socket is given.
DEFAULT_PORT = 12306  # 23: W, 06: F (positions of the letters in the alphabet)

# Authelia group a user must belong to in order to be found via WebFinger.
GROUP_NAME = "webfinger"
# Authelia group -> WebFinger link relation advertised to members of that group.
GROUP_REL_URL_MAP = {"oidc": "http://openid.net/specs/connect/1.0/issuer"}
def args(p: argparse.ArgumentParser) -> None:
    """Register this program's own command-line options on *p*.

    Adds the required `--issuer-url`/`-i` and `--user-domain`/`-u` options,
    then the positional path of Authelia's users database file.
    """
    # (flags, keyword arguments) in registration order; the order matters
    # for help output and positional parsing.
    option_specs = (
        (("--issuer-url", "-i"), dict(required=True, help="the issuer URL to use")),
        (("--user-domain", "-u"), dict(required=True, help="the domain name to use for accounts")),
        (("users_database_yml",), dict(help="the users_database.yml file to use")),
    )
    for flags, settings in option_specs:
        p.add_argument(*flags, **settings)
def app(options: argparse.Namespace) -> web.Application:
    """Build the aiohttp application described by the parsed *options*.

    Loads the users database named by `options.users_database_yml`, mounts
    the WebFinger routes, and installs the `response_headers` hook.
    """
    issuer_url = options.issuer_url
    users_db = AutheliaUsersDB(options.users_database_yml, options.user_domain)
    root = WebFingerAutheliaRoot(issuer_url=issuer_url, users_db=users_db)

    application = web.Application()
    application.add_routes(root.routes)
    # type signature bug: <https://github.com/aio-libs/aiohttp/issues/11264>
    # fixed in aiohttp 3.12.14; Alpine edge still has 3.11.15
    application.on_response_prepare(response_headers)  # type: ignore[arg-type]
    return application
async def response_headers(req: web.Request, resp: web.Response) -> None:
    """Response-prepare hook: rebrand the Server header and allow any CORS origin.

    The Server header is rewritten to SERVER only when it currently equals
    aiohttp's SERVER_SOFTWARE value. Access-Control-Allow-Origin is set to
    "*" only if the response does not already carry that header.
    """
    headers = resp.headers
    current_server = headers.get(hdrs.SERVER)
    if current_server == SERVER_SOFTWARE:
        headers[hdrs.SERVER] = SERVER
    headers.setdefault(hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, "*")
@dataclass
class WebFingerAutheliaRoot(Root):
    """WebFinger (RFC 7033) endpoint backed by an Authelia users database.

    For an `acct:user@domain` resource whose user is found by
    `users_db.get_user(..., match_group=GROUP_NAME)`, the response lists one
    link per `GROUP_REL_URL_MAP` entry whose group the user belongs to.
    Each link points at `issuer_url`.
    """

    _: KW_ONLY
    issuer_url: str  # href of every link returned
    users_db: "AutheliaUsersDB"

    def __post_init__(self):
        # The view class is defined here so its handler can close over `self`
        # (this root); inside the handler, `view` is the per-request view.
        @self.routes.view("/{_:.*}")
        class WebFingerView(web.View):
            async def get(view) -> web.Response:
                query = view.request.query
                if not (resource := query.get("resource", "")):
                    raise web.HTTPBadRequest(text="resource is missing or empty")
                if (acct := resource.removeprefix("acct:")) == resource:
                    raise web.HTTPBadRequest(text="resource must start with `acct:`")
                # get_user() raises ValueError for an acct without "@"; reject
                # that here as a client error instead of letting it become a 500.
                if "@" not in acct:
                    raise web.HTTPBadRequest(text="resource must have the form `acct:user@domain`")
                if not (user := self.users_db.get_user(acct, match_group=GROUP_NAME)):
                    raise web.HTTPNotFound(text="resource not found")

                # RFC 7033 section 4.3: `rel` may be given several times, and
                # when it is absent every link relation is returned.
                requested_rels = set(query.getall("rel", []))
                links: list[dict[str, str]] = [
                    {"rel": rel_url, "href": self.issuer_url}
                    for group, rel_url in GROUP_REL_URL_MAP.items()
                    if group in user["groups"]
                    and (not requested_rels or rel_url in requested_rels)
                ]
                result = {
                    "subject": resource,
                    "links": links,
                }
                # application/jrd+json is the JRD media type (RFC 7033 section 10.2).
                return web.json_response(result, content_type="application/jrd+json")
class AutheliaUsersDB:
path: str
domain: str
db: dict
_read_lock: "threading.RLock"
_watch_thread: threading.Thread
def __init__(self, path: str, domain: str) -> None:
self.path = path
self.domain = domain
self._read_lock = threading.RLock()
self.db = self.read_db()
self._watch_thread = threading.Thread(target=self._watch_thread_target, daemon=True)
self._watch_thread.start()
def get_user(self, acct: str, *, match_group: str | None = None) -> dict | None:
with self._read_lock:
if "@" not in acct:
raise ValueError("acct must have the form of an email address")
username, domain = acct.rsplit("@", 1)
if domain != self.domain:
return None
users = self.db["users"]
for try_username, try_user in users.items():
if not isinstance(users, dict):
raise ValueError(f"value for user `{username}` in users database must be a dictionary")
if try_username == username:
if "disabled" in try_user:
if not isinstance(try_user["disabled"], bool):
raise ValueError(f"disabled key for user `{username}` must be a boolean")
if try_user["disabled"]:
return None
try_user["groups"] = try_user.get("groups", [])
if match_group:
if try_user["groups"] == []:
continue
if not isinstance(try_user["groups"], list):
raise ValueError(f"groups key for user `{username}` must be a list if present")
if match_group in try_user["groups"]:
return try_user
else:
return try_user
return None
def read_db(self) -> dict:
with self._read_lock:
with open(self.path, "r") as f:
db = yaml.safe_load(f.read())
if not isinstance(db, dict):
raise ValueError("users database must be a dictionary")
if "users" not in db:
raise ValueError("users database is missing users key")
if not isinstance(db["users"], dict):
raise ValueError("users key in users database must be a dictionary")
return db
def _watch_thread_target(self) -> None:
wm = pyinotify.WatchManager()
wm.add_watch(os.path.dirname(self.path), pyinotify.IN_CREATE | pyinotify.IN_MODIFY)
class EventHandler(pyinotify.ProcessEvent):
def process_default(_, event):
if event.pathname == self.path and os.stat(self.path).st_size:
self.db = self.read_db()
notifier = pyinotify.Notifier(wm, EventHandler())
notifier.loop()
def main(argv: list[str]) -> int:
    """Run the WebFinger server for the command line *argv* (argv[0] is the program).

    Listens on the UNIX socket given by --socket if present, otherwise on
    --host (default 127.0.0.1) and --port (default DEFAULT_PORT).
    Returns the exit status. On argparse exit (errors, --help) that exit's code is returned.
    """
    import stat  # only needed for the stale-socket check below

    prog = os.path.basename(argv[0])
    p = argparse.ArgumentParser(
        prog=prog,
    )
    p.add_argument("--host", "-H", default=None,
                   help="the host to listen on (default: 127.0.0.1, unless socket is given)")
    p.add_argument("--port", "-P", type=int, default=None,
                   help=f"the port to listen on (default: {DEFAULT_PORT}, unless socket is given)")
    p.add_argument("--socket", "-S", default=None,
                   help="the UNIX socket to listen on (default: none)")
    args(p)
    try:
        options = p.parse_args(argv[1:])
    except SystemExit as exc:
        return int(exc.code) if exc.code is not None else 127

    host, port = options.host, options.port
    sock = None
    if options.socket:
        # dirname() is "" for a bare file name, and os.makedirs("") raises.
        if sock_dir := os.path.dirname(options.socket):
            os.makedirs(sock_dir, 0o755, exist_ok=True)
        # bind() fails if the path already exists, e.g. a socket file left by
        # a previous run; remove such a socket, but never any other file type.
        try:
            if stat.S_ISSOCK(os.lstat(options.socket).st_mode):
                os.unlink(options.socket)
        except FileNotFoundError:
            pass
        sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
        if hasattr(socket, "SO_REUSEADDR"):
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(options.socket)
        sock.listen()
        # World-accessible so any local client (e.g. a reverse proxy) can connect.
        os.chmod(options.socket, 0o777)
    else:
        host = host or "127.0.0.1"
        # `is None` rather than `or`, so an explicit --port 0 is honored.
        port = DEFAULT_PORT if port is None else port
    web.run_app(
        app=app(options),
        host=host,
        port=port,
        sock=sock,
        access_log=None,
    )
    return 0
if __name__ == "__main__":
    # Script entry point: optionally set the process title, then run main().
    try:
        # Best effort: make the process title "<script name> <args...>" instead
        # of the interpreter command line.
        # pip: setproctitle; apt/dnf: python3-setproctitle; apk: py3-setproctitle
        from setproctitle import setproctitle as _setproctitle
        # os.putenv()/os.unsetenv() change only the real process environment,
        # not os.environ, so SPT_NOENV=1 is in effect just for this call.
        # NOTE(review): per setproctitle's docs, SPT_NOENV stops it from
        # clobbering the environ memory area — confirm that is the intent.
        os.putenv("SPT_NOENV", "1")
        _setproctitle(" ".join([os.path.basename(sys.argv[0])] + sys.argv[1:]))
        os.unsetenv("SPT_NOENV")
        # Restore a pre-existing SPT_NOENV from os.environ (still unchanged);
        # KeyError here just means it was not set, which is caught below.
        os.putenv("SPT_NOENV", os.environ["SPT_NOENV"])
    except (ImportError, KeyError):
        # setproctitle is optional; a missing module is not an error.
        pass
    try:
        sys.exit(main(sys.argv))
    except KeyboardInterrupt:
        # Exit quietly on Ctrl+C instead of printing a traceback.
        pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment