Skip to content

Instantly share code, notes, and snippets.

@s3rgeym
Last active May 20, 2024 23:48
Show Gist options
  • Save s3rgeym/f47515fafb31b1ab0cd0295cf5bd427e to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
from __future__ import annotations
import argparse
import dataclasses
import json
import logging
import re
import sys
from pathlib import Path
from typing import Any, Sequence, TextIO
import pymysql
import yaml
# ANSI escape sequences used to colorize log output.
CSI = "\x1b["  # Control Sequence Introducer: prefix of every code below
RESET = CSI + "m"
RED = CSI + "31m"
GREEN = CSI + "32m"
YELLOW = CSI + "33m"
BLUE = CSI + "34m"
PURPLE = CSI + "35m"
class NameSpace(argparse.Namespace):
    """Typed namespace describing the parsed command-line arguments."""

    connect_timeout: float  # -ct / --connect-timeout
    output: TextIO  # -o / --output
    path: Path  # -p / --path
    verbose: int  # number of -v flags given
def create_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the command-line interface.

    Options: the directory to scan (``-p``), the file that receives results
    (``-o``, ``-`` by default), the connection timeout (``-ct``) and a
    repeatable verbosity flag (``-v``).
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        (("-p", "--path"), {"type": Path, "default": Path(".")}),
        (("-o", "--output"), {"type": argparse.FileType("w+"), "default": "-"}),
        (("-ct", "--connect-timeout"), {"type": float, "default": 10.0}),
        (("-v", "--verbose"), {"action": "count", "default": 0}),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli
class ColorHandler(logging.StreamHandler):
    """Stream handler that wraps every formatted record in an ANSI color."""

    # Color used for each of the standard logging levels.
    LEVEL_COLORS = {
        logging.DEBUG: BLUE,
        logging.INFO: GREEN,
        logging.WARNING: RED,
        logging.ERROR: RED,
        logging.CRITICAL: RED,
    }

    _fmt = logging.Formatter("[%(levelname).1s] %(asctime)s - %(message)s")

    def format(self, record: logging.LogRecord) -> str:
        """Return the formatted record, colored according to its level."""
        message = self._fmt.format(record)
        # Records with a non-standard level number have no entry in the
        # table; emit them uncolored instead of failing with KeyError.
        color = self.LEVEL_COLORS.get(record.levelno, "")
        return f"{color}{message}{RESET}"
# Module-wide logger; main() sets its level and attaches the handler.
logger = logging.getLogger(__name__)

# Image names that indicate a MySQL or MariaDB container.
MYSQL_IMAGE_RE = re.compile(r"mysql|mariadb")

# Environment variable names that look like database connection settings:
# a DB/MYSQL/POSTGRES/PG marker followed later by a setting name.
DB_KEYS_RE = re.compile(
    r"(DB|MYSQL|POSTGRES(QL)?|PG)\w*(HOST|USER|PASS|PWD|PORT|DATABASE|NAME)"
)
def handle_compose(file_path: Path, args: NameSpace) -> None:
    """Check every service declared in a Compose file.

    The name of the directory that contains the file is handed to each
    service check as the default hostname.
    """
    services = get_compose_services(file_path)
    fallback_host = file_path.parent.name
    for service_name, service in services.items():
        check_compose_service(service, service_name, fallback_host, args)
def check_compose_service(
    data: dict[str, Any], service_name: str, default_hostname: str, args: NameSpace
) -> None:
    """Look for database settings in one Compose service and test them.

    Environment variables whose names match ``DB_KEYS_RE`` are collected.
    If the service looks like MySQL/MariaDB, a connection is attempted with
    the settings found.

    Args:
        data: The service definition as loaded from YAML.
        service_name: Name of the service (used for logging only).
        default_hostname: Host to connect to when the environment names none.
        args: Parsed command-line arguments.
    """
    logger.debug(f"check compose service: {service_name}")
    if not isinstance(data, dict):
        return
    # ``environment`` may be missing, null, a mapping, or a list of
    # "KEY=value" strings.
    environment = data.get("environment") or {}
    if isinstance(environment, list):
        # Entries without "=" carry no value in the file, so they are skipped.
        environment = dict(
            item.split("=", 1)
            for item in environment
            if isinstance(item, str) and "=" in item
        )
    elif not isinstance(environment, dict):
        return
    db_keys = {
        key: env_value(value)
        for key, value in environment.items()
        if isinstance(key, str) and DB_KEYS_RE.search(key)
    }
    if not db_keys:
        return
    if not is_mysql_service(data, db_keys):
        return
    try:
        db_conf = DBConfig.from_dict(db_keys)
    except (TypeError, ValueError) as error:
        # E.g. a port value that is not a number; skip this service only.
        logger.warning(f"invalid database settings in {service_name}: {error}")
        return
    db_conf.host = db_conf.host or default_hostname
    db_conf.port = db_conf.port or compose_host_port(3306, data)
    check_mysql(db_conf, args)
def check_mysql(db_conf: DBConfig, args: NameSpace) -> None:
    """Try to connect to MySQL and report working credentials.

    On success one JSON line describing the configuration is written to
    ``args.output``. On failure a warning is logged and nothing is written.

    Args:
        db_conf: Connection settings to try.
        args: Parsed command-line arguments (timeout and output stream).
    """
    try:
        connection = pymysql.connect(
            host=db_conf.host,
            # The discovered port must be passed on; fall back to the MySQL
            # default when none was found.
            port=db_conf.port or 3306,
            user=db_conf.username,
            password=db_conf.password,
            database=db_conf.database,
            connect_timeout=args.connect_timeout,
        )
    except Exception as error:
        # Deliberately broad: any failure just means "these settings do not
        # work" and must not abort the scan of the remaining files.
        logger.warning(f"mysql connection failed: {db_conf=}, {error=}")
        return
    # Only the ability to connect matters, so release the connection at once.
    try:
        connection.close()
    except Exception as error:
        logger.debug(f"error while closing mysql connection: {error=}")
    logger.info(f"mysql connection succeeded: {db_conf=}")
    js = json.dumps(
        {"service": "mysql", "config": dataclasses.asdict(db_conf)},
        ensure_ascii=True,
    )
    print(js, file=args.output, flush=True)
class Substring(str):
    """String whose ``==`` is a substring test rather than an equality test.

    ``Substring(text) == part`` is true when ``part`` occurs inside ``text``.
    Used as the subject of a ``match`` statement, string literal patterns
    therefore behave as "contains" checks.
    """

    def __eq__(self, other):
        return str.__contains__(self, other)
@dataclasses.dataclass
class DBConfig:
    """Database connection settings collected from a configuration file."""

    database: str | None = None
    host: str | None = None
    password: str | None = None
    port: int | None = None
    username: str = "root"

    @classmethod
    def from_dict(cls, dic: dict[str, Any]) -> DBConfig:
        """Build a config from a mapping of setting names to values.

        Each key is classified by the marker it contains, checked in this
        order: ``HOST``, ``PORT``, ``USER``, ``PASS``/``PWD``,
        ``DB_NAME``/``DATABASE``. Keys with no marker are ignored. When
        several keys map to the same field, the last one wins.

        A port value that cannot be converted to ``int`` is skipped with a
        warning instead of raising.
        """
        logger.debug(f"{dic}")
        config = cls()
        for key, value in dic.items():
            name = str(key)
            if "HOST" in name:
                config.host = value
            elif "PORT" in name:
                try:
                    config.port = int(value)
                except (TypeError, ValueError):
                    logger.warning(f"ignoring non-numeric port: {name}={value!r}")
            elif "USER" in name:
                config.username = value
            elif "PASS" in name or "PWD" in name:
                config.password = value
            elif "DB_NAME" in name or "DATABASE" in name:
                config.database = value
        return config
def compose_host_port(container_port: int, data: dict) -> int | None:
    """Return the host port mapped to *container_port*, or ``None``.

    Only string entries of the service's ``ports`` list are examined. The
    first entry that targets *container_port* decides the result. If its
    host part cannot be turned into a number, a warning is logged and
    ``None`` is returned.

    Handled entry shapes include ``"3307:3306"``, ``"3307:3306/tcp"``,
    ``"127.0.0.1:3307:3306"`` and ``"${FORWARDED_MYSQL_PORT:-3306}:3306"``.
    """
    suffix = f":{container_port}"
    for entry in data.get("ports") or []:
        if not isinstance(entry, str):
            # Bare integers and mapping-style entries are not handled here.
            continue
        mapping = entry.removesuffix("/tcp").removesuffix("/udp")
        if not mapping.endswith(suffix):
            continue
        published = mapping[: -len(suffix)]
        if published.endswith("}"):
            # Keep the whole "${...}" placeholder, dropping a leading "IP:".
            start = published.find("${")
            if start > 0:
                published = published[start:]
        else:
            # "IP:port" -> "port"
            published = published.rsplit(":", 1)[-1]
        try:
            return int(env_value(published))
        except (IndexError, TypeError, ValueError) as e:
            logger.warning(e)
            return None
    return None
def is_mysql_service(data: dict, db_keys: dict) -> bool:
    """Tell whether a Compose service looks like MySQL or MariaDB.

    A service qualifies when its image name matches ``MYSQL_IMAGE_RE``, when
    one of its string ``ports`` entries targets container port 3306, or when
    any of the database-related environment keys contains ``MYSQL``.

    Args:
        data: The service definition.
        db_keys: Database-related environment variables of the service.
    """
    image = data.get("image")
    if isinstance(image, str) and MYSQL_IMAGE_RE.search(image):
        return True
    # ``ports`` may be missing or null, and may hold non-string entries
    # (bare integers, mappings), which are ignored here.
    for entry in data.get("ports") or []:
        if isinstance(entry, str) and entry.removesuffix("/tcp").endswith(":3306"):
            return True
    return any("MYSQL" in key for key in db_keys)
def env_value(value: Any) -> Any:
    """Resolve a ``${VAR:-default}`` placeholder to its default value.

    Anything that is not a string made of a single ``${...}`` placeholder is
    returned unchanged. Both ``${VAR:-default}`` and ``${VAR-default}`` yield
    ``default``. A placeholder without a default (for example ``${VAR}``)
    yields an empty string, because the variable's value is not known here.

    >>> env_value('${DB_NAME:-wordpress}')
    'wordpress'
    >>> env_value('${DB_NAME}')
    ''
    >>> env_value('secret')
    'secret'
    """
    if not isinstance(value, str):
        return value
    if not (value.startswith("${") and value.endswith("}")):
        return value
    inner = value[2:-1]
    _, found, default = inner.partition(":-")
    if found:
        return default
    # "${VAR-default}": only accept it when the part before "-" is a plain
    # variable name, so forms such as "${VAR:?some-message}" are not misread.
    name, found, default = inner.partition("-")
    if found and name.isidentifier():
        return default
    return ""
def get_compose_services(file_path: Path) -> dict:
    """Load the ``services`` mapping from a Compose file.

    Returns an empty dict, after logging a warning, when the file cannot be
    read or parsed, or when it has no ``services`` mapping.
    """
    try:
        # Binary mode lets the YAML parser detect the encoding itself.
        with file_path.open("rb") as stream:
            data = yaml.safe_load(stream)
    except Exception as e:
        # Deliberately broad: an unreadable or malformed file must not abort
        # the scan of the remaining files.
        logger.warning(e)
        return {}
    services = data.get("services") if isinstance(data, dict) else None
    if not isinstance(services, dict):
        logger.warning(f"no services mapping found in {file_path!s}")
        return {}
    return services
def handle_env(file_path: Path, args: NameSpace) -> None:
    """Placeholder handler for ``.env`` files; it currently does nothing."""
    return None
# Quoted string literals. A backslash always consumes the character after
# it, so an escaped quote can never be mistaken for the end of the literal.
SINGLE_QUOTED_PAT = r"'(?:\\[\s\S]|[^'\\])*'"
DOUBLE_QUOTED_PAT = r'"(?:\\[\s\S]|[^"\\])*"'
QUOTED_PAT = f"{SINGLE_QUOTED_PAT}|{DOUBLE_QUOTED_PAT}"

# Matches ``define('KEY', 'value');`` where both arguments are quoted
# strings; the ``key`` and ``value`` groups include the surrounding quotes.
WP_CONFIG_RE = re.compile(
    r"define\(\s*(?P<key>"
    + QUOTED_PAT
    + r")\s*,\s*(?P<value>"
    + QUOTED_PAT
    + r")\s*\);"
)
def handle_wp_config(file_path: Path, args: NameSpace) -> None:
    """Extract the database settings from a wp-config file and test them.

    Only ``define()`` constants whose name starts with ``DB_`` are used, so
    unrelated constants that merely contain ``USER``, ``PASS`` or ``HOST``
    cannot overwrite the database settings. When the configured host is the
    local machine, the name of the directory containing the file is used as
    the host instead.
    """
    try:
        # Undecodable bytes are replaced so one oddly encoded file cannot
        # abort the scan.
        contents = file_path.read_text(errors="replace")
    except OSError as error:
        logger.warning(f"cannot read {file_path!s}: {error}")
        return
    wp_conf: dict[str, str] = {
        m.group("key")[1:-1]: m.group("value")[1:-1]
        for m in WP_CONFIG_RE.finditer(contents)
    }
    db_settings = {k: v for k, v in wp_conf.items() if k.startswith("DB_")}
    if not db_settings:
        # Nothing to test; avoid a pointless connection attempt.
        logger.debug(f"no database settings found in {file_path!s}")
        return
    db_conf = DBConfig.from_dict(db_settings)
    if db_conf.host in ("localhost", "127.0.0.1"):
        db_conf.host = file_path.parent.name
    check_mysql(db_conf, args)
# Glob pattern (relative to the scanned directory) -> handler for the
# matching files. Compose files may use either the .yml or .yaml extension.
HANDLERS_MAP = {
    "**/*compose*.yml": handle_compose,
    "**/*compose*.yaml": handle_compose,
    "**/*.env": handle_env,
    "**/wp-config.php*": handle_wp_config,
}
def main(argv: Sequence[str] | None = None) -> None:
    """Scan the given directory and run the matching handler on each file.

    Args:
        argv: Command-line arguments without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]`` at call time.
    """
    parser = create_parser()
    args: NameSpace = parser.parse_args(argv)
    # Each -v lowers the threshold by one level, but never below DEBUG.
    lvl = max(logging.DEBUG, logging.WARNING - logging.DEBUG * args.verbose)
    logger.setLevel(level=lvl)
    # Do not stack handlers if main() is called more than once.
    if not any(isinstance(h, ColorHandler) for h in logger.handlers):
        logger.addHandler(ColorHandler())
    for pat, hdlr in HANDLERS_MAP.items():
        for path in args.path.glob(pat):
            if not path.is_file():
                # A pattern can also match directories.
                continue
            logger.debug(f"handle {path!s}")
            try:
                hdlr(path, args)
            except Exception:
                # One broken file must not stop the rest of the scan.
                logger.exception(f"handler failed for {path!s}")


if __name__ == "__main__":
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment