Last active
May 20, 2024 23:48
-
-
Save s3rgeym/f47515fafb31b1ab0cd0295cf5bd427e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
from __future__ import annotations | |
import argparse | |
import dataclasses | |
import json | |
import logging | |
import re | |
import sys | |
from pathlib import Path | |
from typing import Any, Sequence, TextIO | |
import pymysql | |
import yaml | |
# ANSI escape sequences used to colorize log output.
CSI = "\x1b["  # Control Sequence Introducer
RESET = CSI + "m"
RED, GREEN, YELLOW, BLUE, PURPLE = (
    CSI + code + "m" for code in ("31", "32", "33", "34", "35")
)
class NameSpace(argparse.Namespace):
    """Typed view of the options produced by ``create_parser``."""

    # Root directory searched with recursive ``**/`` glob patterns.
    path: Path
    # Stream that receives one JSON line per successful connection.
    output: TextIO
    # Timeout handed to ``pymysql.connect(connect_timeout=...)``.
    connect_timeout: float
    # Number of ``-v`` flags; each one lowers the log threshold one level.
    verbose: int
def create_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the scanner.

    Options: ``-p/--path`` (search root, default ``.``), ``-o/--output``
    (JSON result stream, default stdout), ``-ct/--connect-timeout`` (seconds,
    default 10.0) and the repeatable ``-v/--verbose`` counter.
    """
    ap = argparse.ArgumentParser()
    option_specs = (
        (("-p", "--path"), {"type": Path, "default": Path(".")}),
        (("-o", "--output"), {"type": argparse.FileType("w+"), "default": "-"}),
        (("-ct", "--connect-timeout"), {"type": float, "default": 10.0}),
        (("-v", "--verbose"), {"action": "count", "default": 0}),
    )
    for flags, spec in option_specs:
        ap.add_argument(*flags, **spec)
    return ap
class ColorHandler(logging.StreamHandler):
    """Stream handler that wraps each formatted record in an ANSI color.

    The color is looked up in ``LEVEL_COLORS`` by the record's numeric level.
    Levels not listed there (e.g. custom levels registered with
    ``logging.addLevelName``) are emitted uncolored. Previously they raised
    ``KeyError`` inside ``emit``, which dropped the message.
    """

    LEVEL_COLORS = {
        logging.DEBUG: BLUE,
        logging.INFO: GREEN,
        logging.WARNING: RED,
        logging.ERROR: RED,
        logging.CRITICAL: RED,
    }
    # Fixed format: one-letter level, timestamp, message.
    _fmt = logging.Formatter("[%(levelname).1s] %(asctime)s - %(message)s")

    def format(self, record: logging.LogRecord) -> str:
        """Format *record* with ``_fmt`` and colorize it by level."""
        message = self._fmt.format(record)
        color = self.LEVEL_COLORS.get(record.levelno)
        if color is None:
            return message
        return f"{color}{message}{RESET}"
logger = logging.getLogger(__name__)

# Docker image names that indicate a MySQL-compatible server.
MYSQL_IMAGE_RE = re.compile(r"mysql|mariadb")

# Environment-variable names that look like database settings: a DB-ish
# prefix followed (anywhere later in the name) by a connection-field marker.
DB_KEYS_RE = re.compile(
    r"(DB|MYSQL|POSTGRES(QL)?|PG)"  # prefix
    r"\w*"
    r"(HOST|USER|PASS|PWD|PORT|DATABASE|NAME)"  # field marker
)
def handle_compose(file_path: Path, args: NameSpace) -> None:
    """Check every service declared in the compose file at *file_path*.

    The name of the directory containing the file is passed on as the
    fallback hostname for services that do not set one.
    """
    fallback_host = file_path.parent.name
    services = get_compose_services(file_path)
    for service_name, service in services.items():
        check_compose_service(service, service_name, fallback_host, args)
def check_compose_service(
    data: dict[str, Any], service_name: str, default_hostname: str, args: NameSpace
) -> None:
    """Probe one compose service if its environment holds MySQL credentials.

    Environment entries whose names match ``DB_KEYS_RE`` are collected, with
    their values passed through ``env_value``. If ``is_mysql_service`` accepts
    the service, a ``DBConfig`` is built from them. A missing host is filled in
    with *default_hostname* and a missing port with the published host port
    for 3306. The result is handed to ``check_mysql``.

    Both environment syntaxes are accepted: a mapping, or a list of
    ``"KEY=value"`` strings. List items without ``=`` carry no value in the
    file and are skipped rather than crashing ``dict()``. An empty service
    body or ``environment:`` section (parsed as ``None``) means nothing to check.
    """
    logger.debug(f"check compose service: {service_name}")
    if not isinstance(data, dict):
        return
    env = data.get("environment") or {}
    if isinstance(env, list):
        env = dict(
            item.split("=", 1) for item in env if isinstance(item, str) and "=" in item
        )
    if not isinstance(env, dict):
        return
    db_keys = {k: env_value(v) for k, v in env.items() if DB_KEYS_RE.search(k)}
    if not db_keys or not is_mysql_service(data, db_keys):
        return
    db_conf = DBConfig.from_dict(db_keys)
    db_conf.host = db_conf.host or default_hostname
    db_conf.port = db_conf.port or compose_host_port(3306, data)
    check_mysql(db_conf, args)
def check_mysql(db_conf: DBConfig, args: NameSpace) -> None:
    """Try to log in to MySQL with *db_conf* and report a success as JSON.

    On success, one line ``{"service": "mysql", "config": {...}}`` is written
    to ``args.output`` and the connection is closed again. A connection
    failure of any kind is logged as a warning and swallowed, so that one
    unreachable target does not stop the scan.
    """
    try:
        conn = pymysql.connect(
            host=db_conf.host,
            # The port discovered from the config was previously ignored;
            # fall back to the MySQL default when none was found.
            port=db_conf.port or 3306,
            user=db_conf.username,
            password=db_conf.password,
            database=db_conf.database,
            connect_timeout=args.connect_timeout,
        )
    except Exception as error:  # best effort: any failure means "no access"
        logger.warning(f"mysql connection failed: {db_conf=}, {error=}")
        return
    try:
        logger.info(f"mysql connection succeeded: {db_conf=}")
        js = json.dumps(
            {"service": "mysql", "config": dataclasses.asdict(db_conf)},
            ensure_ascii=True,
        )
        print(js, file=args.output, flush=True)
    finally:
        # Do not leak one open connection per successful target.
        conn.close()
class Substring(str): | |
__eq__ = str.__contains__ | |
@dataclasses.dataclass | |
class DBConfig: | |
database: str = None | |
host: str = None | |
password: str = None | |
port: int = None | |
username: str = "root" | |
@classmethod | |
def from_dict(cls, dic: dict[str, Any]) -> DBConfig: | |
logger.debug(f"{dic}") | |
c = cls() | |
for k, v in dic.items(): | |
match Substring(k): | |
case "HOST": | |
c.host = v | |
case "PORT": | |
c.port = int(v) | |
case "USER": | |
c.username = v | |
case "PASS" | "PWD": | |
c.password = v | |
case "DB_NAME" | "DATABASE": | |
c.database = v | |
return c | |
def compose_host_port(container_port: int, data: dict) -> int | None: | |
if v := next( | |
(x for x in data.get("ports", []) if x.endswith(f":{container_port}")), None | |
): | |
try: | |
# ${FORWARDED_MYSQL_PORT:-3306}:3306 | |
return int(env_value(v.rsplit(":", 1)[0])) | |
except ValueError as e: | |
logger.warning(e) | |
return None | |
def is_mysql_service(data: dict, db_keys: dict) -> bool:
    """Heuristically decide whether a compose service runs MySQL/MariaDB.

    True when any of the following holds:

    - the image name matches ``MYSQL_IMAGE_RE``;
    - some port entry targets container port 3306 (short syntax as a string
      or number, optionally with ``/proto``, or long syntax ``{"target": 3306}``);
    - one of the collected DB environment keys contains ``MYSQL``.

    Non-string images and a ``null`` ``ports`` section are tolerated.
    """
    image = data.get("image")
    if isinstance(image, str) and MYSQL_IMAGE_RE.search(image):
        return True
    for entry in data.get("ports") or []:
        if isinstance(entry, dict):
            target = entry.get("target")
        else:
            target = str(entry).split("/", 1)[0].rpartition(":")[2]
        if str(target) == "3306":
            return True
    return any("MYSQL" in key for key in db_keys)
def env_value(value: str) -> str:
    """Resolve a whole-value compose ``${VAR...}`` reference to its default.

    The environment the compose file would run in is unknown here, so a
    reference resolves to its fallback. ``${VAR:-default}`` and
    ``${VAR-default}`` give ``default``. A bare ``${VAR}`` gives ``""``, which
    is what Compose substitutes for an unset variable; previously it raised
    ``IndexError``. Anything else, including non-strings and other modifier
    forms, is returned unchanged.

    >>> env_value('${DB_NAME:-wordpress}')
    'wordpress'
    >>> env_value('${DB_PORT-3306}')
    '3306'
    >>> env_value('${DB_NAME}')
    ''
    >>> env_value('plain')
    'plain'
    """
    if not isinstance(value, str):
        return value
    m = re.fullmatch(r"\$\{\w+(?::?-(.*))?\}", value, re.DOTALL)
    if m is None:
        return value
    return m.group(1) or ""
def get_compose_services(file_path) -> dict:
    """Return the ``services`` mapping of the compose file at *file_path*.

    Opening, decoding and YAML parse errors are logged as warnings, and
    ``{}`` is returned. Opening used to sit outside the ``try``, so I/O errors
    crashed the scan. A missing, empty or non-mapping ``services`` section
    also yields ``{}``, so callers can always use ``.items()``.
    ``yaml.safe_load`` is used, so an untrusted file cannot construct
    arbitrary Python objects.
    """
    try:
        with file_path.open(encoding="utf-8") as stream:
            data = yaml.safe_load(stream)
    except (OSError, ValueError, yaml.YAMLError) as e:  # ValueError covers decode errors
        logger.warning(f"{file_path}: {e}")
        return {}
    services = data.get("services") if isinstance(data, dict) else None
    if not isinstance(services, dict):
        logger.debug(f"no services in {file_path}")
        return {}
    return services
def handle_env(file_path: Path, args: NameSpace) -> None:
    """Placeholder for ``.env`` file handling; currently does nothing."""
# PHP string literals as they appear in ``define(...)`` calls. A backslash
# escapes the next character, so 'it\'s' is a single literal. The two
# alternatives cannot match the same character, which keeps matching linear.
SINGLE_QUOTED_PAT = r"'(?:[^'\\]|\\.)*'"
DOUBLE_QUOTED_PAT = r'"(?:[^"\\]|\\.)*"'
QUOTED_PAT = f"{SINGLE_QUOTED_PAT}|{DOUBLE_QUOTED_PAT}"
# Matches ``define( 'KEY', 'value' );`` where key and value are quoted
# literals. ``\b`` rejects names such as ``my_define(``. The flags allow
# ``DEFINE`` (PHP function names are case-insensitive) and escaped newlines.
# Defines whose value is not a string literal (``true``, ``getenv(...)``)
# are not matched.
WP_CONFIG_RE = re.compile(
    r"\bdefine\s*\(\s*(?P<key>"
    + QUOTED_PAT
    + r")\s*,\s*(?P<value>"
    + QUOTED_PAT
    + r")\s*\)\s*;",
    re.DOTALL | re.IGNORECASE,
)
def handle_wp_config(file_path: Path, args: NameSpace) -> None:
    """Probe the database credentials defined in a WordPress config file.

    Only ``DB_*`` constants are used. Any other string define containing the
    same ``HOST``/``PASS``/``USER``/``PORT`` markers that
    ``DBConfig.from_dict`` matches by substring (e.g. ``WP_REDIS_HOST``)
    would otherwise overwrite the database settings. The host is adjusted
    before probing:

    - a ``DB_HOST`` of the form ``host:port`` or ``host:/socket`` is split;
    - a local or missing host is replaced by the name of the directory
      containing the file.

    Files that cannot be read, or that define no ``DB_*`` constants, are
    skipped.
    """
    try:
        contents = file_path.read_text(encoding="utf-8", errors="replace")
    except OSError as e:
        logger.warning(f"{file_path}: {e}")
        return
    wp_conf: dict[str, str] = {}
    for m in WP_CONFIG_RE.finditer(contents):
        key = m.group("key")[1:-1]
        if not key.startswith("DB_"):
            continue
        # Undo backslash escapes of quotes/backslashes: 'it\'s' -> it's.
        # NOTE(review): double-quoted PHP strings also support other escapes
        # and variable interpolation, which are not handled here.
        wp_conf[key] = re.sub(r"\\(['\"\\])", r"\1", m.group("value")[1:-1])
    if not wp_conf:
        logger.debug(f"no DB_* defines in {file_path!s}")
        return
    db_conf = DBConfig.from_dict(wp_conf)
    # Only a single colon is split, so bracketed IPv6 hosts are left alone.
    if db_conf.host and db_conf.host.count(":") == 1:
        host, _, tail = db_conf.host.partition(":")
        db_conf.host = host
        if tail.isdigit() and db_conf.port is None:
            db_conf.port = int(tail)
    if db_conf.host in (None, "", "localhost", "127.0.0.1"):
        db_conf.host = file_path.parent.name
    check_mysql(db_conf, args)
# Glob pattern (relative to --path) -> handler invoked for every match.
# Compose files are accepted with both ``.yml`` and ``.yaml`` extensions;
# ``compose.yaml`` was previously never scanned.
HANDLERS_MAP = {
    "**/*compose*.yml": handle_compose,
    "**/*compose*.yaml": handle_compose,
    "**/*.env": handle_env,
    "**/wp-config.php*": handle_wp_config,
}
def main(argv: Sequence[str] | None = None) -> None:
    """Scan ``--path`` recursively and probe every discovered DB config.

    *argv* defaults to ``sys.argv[1:]``, read by argparse at call time rather
    than frozen when the module is imported. A failure while processing one
    file is logged as an error and the scan continues with the next file.
    """
    parser = create_parser()
    args = parser.parse_args(argv, namespace=NameSpace())
    # 0 x -v -> WARNING, -v -> INFO, -vv (or more) -> DEBUG.
    lvl = max(logging.DEBUG, logging.WARNING - logging.DEBUG * args.verbose)
    logger.setLevel(level=lvl)
    # Repeated calls must not stack handlers (each would duplicate output).
    if not any(isinstance(h, ColorHandler) for h in logger.handlers):
        logger.addHandler(ColorHandler())
    for pat, hdlr in HANDLERS_MAP.items():
        for path in args.path.glob(pat):
            logger.debug(f"handle {path!s}")
            try:
                hdlr(path, args)
            except Exception as e:  # one broken file must not abort the scan
                logger.error(f"failed to handle {path!s}: {e!r}")
if __name__ == "__main__":
    # main() returns None, so the process exits with status 0 unless it raises.
    raise SystemExit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment