Skip to content

Instantly share code, notes, and snippets.

@chrishronek
Created November 19, 2023 18:08
Show Gist options
  • Save chrishronek/8f94fb72e539e0459f48342218abbcc8 to your computer and use it in GitHub Desktop.
Save chrishronek/8f94fb72e539e0459f48342218abbcc8 to your computer and use it in GitHub Desktop.
GitHub Actions-friendly script that checks a dbt Core project's test and documentation coverage against the .sql models in the /models directory
import os
import sys
import uuid
import yaml
from jinja2 import BaseLoader, Environment
from sqlglot import exp, parse_one
from sqlglot.dialects import Redshift
from tabulate import tabulate
def set_multiline_output(name, value):
    """
    Append a multi-line step output to the file named by ``$GITHUB_OUTPUT``.

    The value is written in the ``name<<DELIMITER`` / value / ``DELIMITER`` form,
    using a fresh delimiter for every call.

    :param name: the name of the output
    :param value: the (possibly multi-line) value of the output
    :raises KeyError: if the ``GITHUB_OUTPUT`` environment variable is not set
    """
    # a random delimiter, so the value cannot realistically contain it
    delimiter = uuid.uuid4()
    # write UTF-8 explicitly instead of relying on the runner's locale encoding
    with open(os.environ["GITHUB_OUTPUT"], "a", encoding="utf-8") as fh:
        print(f"{name}<<{delimiter}", file=fh)
        print(value, file=fh)
        print(delimiter, file=fh)
class DbtCoverage:
    """
    Compare the columns selected by the ``.sql`` files under ``<project_dir>/models``
    with the ``models:`` entries of the YAML property files in the same tree, and
    report description and test coverage per model.
    """

    # (count key, column-list key, sentence) for each kind of column discrepancy
    _DISCREPANCY_SECTIONS = (
        (
            "extra_yml_col_count",
            "extra_yml_cols",
            "The following columns don't exist in SQL and can be removed from properties.yml:",
        ),
        (
            "yml_missing_col_count",
            "yml_missing_cols",
            "The following columns are completely missing in the properties.yml:",
        ),
        (
            "missing_col_descr_count",
            "missing_col_descr",
            "The following columns are missing descriptions in properties.yml:",
        ),
    )

    def __init__(self, project_dir: str, filtered_models: list = None):
        """
        :param project_dir: directory that contains the ``models`` folder
        :param filtered_models: model names to restrict the report to; ``None`` or
            an empty list means every model is included
        """
        self.project_dir = project_dir
        self.filtered_models = filtered_models

    def _is_selected(self, model_name) -> bool:
        """Return True when no filter is set or ``model_name`` is in the filter."""
        return not self.filtered_models or model_name in self.filtered_models

    @staticmethod
    def _has_text(value) -> bool:
        """Return True when ``value`` is neither None nor blank (e.g. ``description: ""``)."""
        return value is not None and str(value).strip() != ""

    def parse_sql(self, model_name: str, sql: str) -> str:
        """
        Render the jinja in a model's SQL with placeholder dbt functions.

        :param model_name: the name of the dbt model (rendered wherever ``this`` is used)
        :param sql: The raw sql that has jinja variables
        :return: the rendered SQL
        """
        # initialize the jinja parser
        env = Environment(loader=BaseLoader())

        # set some functions to insert placeholders or default values for dbt functions
        def ref_function(*args, **kwargs):
            # ref('model') and ref('package', 'model'): the model is the last positional;
            # keyword arguments (such as a version) are ignored
            return args[-1] if args else ""

        def value_or_default(name, default=None):
            # var('x') / env_var('X') render the name itself, the two-argument
            # form renders the supplied default
            return name if default is None else default

        def null_it(*args, **kwargs):
            return ""

        def source_function(schema, table_name):
            return f"{schema}.{table_name}"

        def is_incremental_function():
            return True

        env.globals.update(
            ref=ref_function,
            source=source_function,
            is_incremental=is_incremental_function,
            env_var=value_or_default,
            config=null_it,
            var=value_or_default,
        )
        # return the rendered SQL
        return env.from_string(sql).render(this=model_name)

    def get_sql_models(self):
        """
        Walk ``<project_dir>/models`` and collect the selected columns of every ``.sql`` file.

        :return: list of dicts with the keys ``parent`` (name of the containing
            directory), ``model_name`` (file name without extension) and
            ``columns_from_sql`` (aliases and plain column names of the first SELECT found)
        """
        directory = os.path.join(self.project_dir, "models")
        sql_files = []
        for root, _dirs, files in os.walk(directory):
            for file in files:
                if not file.endswith(".sql"):
                    continue
                model_name = os.path.splitext(file)[0]
                # If model isn't in filtered_models skip
                if not self._is_selected(model_name):
                    continue
                # read the raw SQL into a variable
                with open(os.path.join(root, file), encoding="utf-8") as sql_file:
                    sql_code = sql_file.read()
                # parse the jinja stuff
                rendered_sql = self.parse_sql(model_name=model_name, sql=sql_code)
                # Use sqlglot to get model columns; find() gives None when the
                # statement contains no SELECT, in which case there are no columns
                select = parse_one(rendered_sql, dialect=Redshift).find(exp.Select)
                expressions = select.args.get("expressions") or [] if select is not None else []
                column_names = []
                for expression in expressions:
                    if isinstance(expression, exp.Alias):
                        column_names.append(expression.text("alias"))
                    elif isinstance(expression, exp.Column):
                        column_names.append(expression.text("this"))
                # Add to the list of models
                sql_files.append(
                    {
                        "parent": os.path.basename(root),
                        "model_name": model_name,
                        "columns_from_sql": column_names,
                    }
                )
        return sql_files

    def get_yml_models(self):
        """
        Walk ``<project_dir>/models`` and collect the ``models:`` entries of every
        ``.yml`` / ``.yaml`` file.

        A description only counts when it is non-blank, and a column counts as
        tested when it has a non-empty ``tests`` or ``data_tests`` entry.

        :return: list of dicts with the keys ``in_yml_meta``, ``model_name``,
            ``has_table_description``, ``columns_from_yml`` (``col_name`` /
            ``col_descr`` pairs, ``col_descr`` being None when there is no
            description), ``column_descriptions`` and ``has_column_test``
        """
        directory = os.path.join(self.project_dir, "models")
        models = []
        for root, _dirs, files in os.walk(directory):
            for file in files:
                if not file.endswith((".yml", ".yaml")):
                    continue
                with open(os.path.join(root, file), encoding="utf-8") as yaml_file:
                    yaml_contents = yaml.safe_load(yaml_file)
                # an empty file loads as None; anything that isn't a mapping has no models
                if not isinstance(yaml_contents, dict):
                    continue
                for model in yaml_contents.get("models") or []:
                    # If model isn't in filtered_models skip
                    if not self._is_selected(model.get("name")):
                        continue
                    columns = model.get("columns") or []
                    columns_from_yml = []
                    cols_descriptions = 0
                    col_tests = 0
                    for column in columns:
                        description = column.get("description")
                        if self._has_text(description):
                            cols_descriptions += 1
                        else:
                            # blank descriptions are reported as missing
                            description = None
                        if column.get("tests") or column.get("data_tests"):
                            col_tests += 1
                        columns_from_yml.append({"col_name": column.get("name"), "col_descr": description})
                    models.append(
                        {
                            "in_yml_meta": True,
                            "model_name": model.get("name"),
                            "has_table_description": self._has_text(model.get("description")),
                            "columns_from_yml": columns_from_yml,
                            "column_descriptions": cols_descriptions,
                            "has_column_test": col_tests > 0,
                        }
                    )
        return models

    def _score_columns(self, result: dict) -> None:
        """Add the column comparison and coverage keys to one joined model dict (in place)."""
        columns_from_sql = result.get("columns_from_sql", [])
        # name -> description (None when undocumented); also de-duplicates the yml names
        yml_descriptions = {col["col_name"]: col.get("col_descr") for col in result.get("columns_from_yml", [])}
        sql_names = set(columns_from_sql)

        # SQL columns that are listed in the yml but have no description
        missing_col_descr = [
            col for col in columns_from_sql if col in yml_descriptions and yml_descriptions[col] is None
        ]
        # Find columns that should exist in documentation but don't
        yml_missing_cols = [col for col in columns_from_sql if col not in yml_descriptions]
        # Find columns that are documented that don't exist
        yml_extra_cols = [col for col in yml_descriptions if col not in sql_names]

        # Calculate the number of sql columns with descriptions
        unwritten_descriptions = len(yml_missing_cols) + len(missing_col_descr)
        written_descriptions = len(columns_from_sql) - unwritten_descriptions
        # Calculate the percentage of descriptions (0 for a model without columns)
        description_coverage = 100 * (written_descriptions / len(columns_from_sql) if columns_from_sql else 0)

        result["col_description_coverage"] = description_coverage
        result["tbl_description_coverage"] = result.get("has_table_description", False)
        result["test_coverage"] = result.get("has_column_test", False)
        result["missing_col_descr_count"] = len(missing_col_descr)
        result["missing_col_descr"] = missing_col_descr
        result["yml_missing_col_count"] = len(yml_missing_cols)
        result["yml_missing_cols"] = yml_missing_cols
        result["extra_yml_col_count"] = len(yml_extra_cols)
        result["extra_yml_cols"] = yml_extra_cols

    def combine_yml_and_sql(self):
        """
        Left-join the SQL models with their yml entries and score each model.

        :return: one dict per SQL model holding the SQL keys, the yml keys
            (defaults when the model has no yml entry) and the coverage keys
            ``col_description_coverage``, ``tbl_description_coverage``,
            ``test_coverage``, ``missing_col_descr(_count)``,
            ``yml_missing_col(s/_count)`` and ``extra_yml_col(s/_count)``
        """
        sql_files = self.get_sql_models()
        yml_models = self.get_yml_models()
        # Create a dictionary for quick lookup based on the model name
        yml_models_lookup = {item["model_name"]: item for item in yml_models}
        # Perform a left join
        results = []
        for item in sql_files:
            matching_item = yml_models_lookup.get(item["model_name"])
            if matching_item is None:
                # Defaults for a model that has no entry in any yml file
                matching_item = {
                    "in_yml_meta": False,
                    "has_table_description": False,
                    "has_column_test": False,
                    "columns_from_yml": [],
                    "column_descriptions": 0,
                    # not read anywhere in this class; kept so the dict shape is unchanged
                    "column_tests": False,
                }
            item.update(matching_item)
            results.append(item)
        # Column comparison
        for result in results:
            self._score_columns(result)
        return results

    def total_coverage_stats(self, joined_models: list):
        """
        Print the project-wide test coverage (share of models with a column test)
        and docs coverage (mean column description coverage over all models).

        :param joined_models: the list returned by ``combine_yml_and_sql``
        """
        total_models = len(joined_models)
        models_w_tests = sum(1 for model in joined_models if model["test_coverage"] is True)
        col_desc_coverages = [model["col_description_coverage"] for model in joined_models]
        # an empty project reports 0% instead of dividing by zero
        total_test_coverage = 100 * (models_w_tests / total_models) if total_models else 0.0
        average_desc_coverage = sum(col_desc_coverages) / total_models if total_models else 0.0
        print(f"Project Test Coverage: {self._format_percentage_local(total_test_coverage)}")
        print(f"Project Docs Coverage: {self._format_percentage_local(average_desc_coverage)}")

    def total_coverage_stats_tbl(self, joined_models: list, gh_comment: bool = False):
        """
        Build the per-model coverage table.

        :param joined_models: the list returned by ``combine_yml_and_sql``
        :param gh_comment: use emoji shortcodes when True, ANSI-colored terminal text otherwise
        :return: the table rendered by ``tabulate`` in its ``github`` format
        """
        if gh_comment:
            format_boolean, format_percentage = self._format_boolean_gh, self._format_percentage_gh
        else:
            format_boolean, format_percentage = self._format_boolean_local, self._format_percentage_local
        tbl_list = [
            (
                f"{d['parent']}.{d['model_name']}",
                format_boolean(d["tbl_description_coverage"]),
                format_percentage(d["col_description_coverage"]),
                format_boolean(d["test_coverage"]),
            )
            for d in joined_models
        ]
        return tabulate(
            tbl_list,
            headers=["Model", "Tbl Doc", "Column Docs", "Test"],
            tablefmt="github",
        )

    def _discrepancy_sections(self, model: dict) -> list:
        """Return ``(sentence, columns)`` for every discrepancy kind with a count above zero."""
        return [
            (message, model.get(cols_key, []))
            for count_key, cols_key, message in self._DISCREPANCY_SECTIONS
            if model[count_key] > 0
        ]

    def column_report(self, joined_models: list):
        """
        Generate the plain-text column report for local runs.

        :param joined_models: List of joined models with column discrepancies
        :return: the report; an empty string when no model has discrepancies
        """
        parts = []
        for d in joined_models:
            sections = self._discrepancy_sections(d)
            if not sections:
                continue
            parts.append(f"\n\n---{d['parent'].upper()}.{d['model_name'].upper()} COLUMN DISCREPANCIES---")
            for message, cols in sections:
                parts.append(f"\n{message}")
                parts.extend(f"\n - {col}" for col in cols)
        return "".join(parts)

    def column_report_md(self, joined_models: list) -> str:
        """
        Generate the column report in GitHub-friendly Markdown format.

        :param joined_models: List of joined models with column discrepancies
        :return: GitHub-friendly Markdown formatted report (one ``<details>`` element per model)
        """
        parts = []
        for d in joined_models:
            sections = self._discrepancy_sections(d)
            if not sections:
                continue
            parts.append(
                f"\n\n<details>\n\n <summary>{d['parent']}.{d['model_name']} column details</summary>\n\n"
            )
            for message, cols in sections:
                parts.append(f"\n{message}\n")
                parts.extend(f" - {col}\n" for col in cols)
            parts.append("\n\n</details>")
        return "".join(parts)

    def _colorize(self, value, color):
        """Wrap ``value`` in the ANSI escape codes for ``color`` ("green" or "red")."""
        # ANSI color codes for terminal output
        color_codes = {"green": "\033[92m", "red": "\033[91m", "reset": "\033[0m"}
        return f"{color_codes[color]}{value}{color_codes['reset']}"

    def _format_boolean_local(self, value):
        """Return a green check mark for a truthy value, a red cross otherwise."""
        return self._colorize("✔" if value else "✘", "green" if value else "red")

    def _format_boolean_gh(self, value):
        """Return the GitHub emoji shortcode for a truthy / falsy value."""
        return ":white_check_mark:" if value else ":x:"

    def _format_percentage_local(self, value, threshold: int = 80):
        """Return the rounded percentage, red below ``threshold`` and green otherwise."""
        color = "red" if value < threshold else "green"
        return self._colorize(f"{round(value, 2)}%", color)

    def _format_percentage_gh(self, value, threshold: int = 80):
        """Return the rounded percentage with a warning emoji below ``threshold``, a medal otherwise."""
        emoji = ":warning:" if value < threshold else ":medal_sports:"
        return f"{emoji} {round(value, 2)}%"
if __name__ == "__main__":
    # argparse is only needed when the file is run as a script
    import argparse

    # Parse the command line; a missing --project_dir or a flag without a value
    # now ends with a usage message instead of a TypeError / IndexError
    parser = argparse.ArgumentParser(
        description="Report documentation and test coverage of the models in a dbt project.",
        allow_abbrev=False,
    )
    parser.add_argument("--project_dir", required=True, help="directory that contains the models folder")
    parser.add_argument(
        "--filtered_models",
        default="",
        help="whitespace-separated model names to restrict the report to (default: all models)",
    )
    # unrecognized arguments are ignored, as the manual parsing did
    cli_args, _unknown_args = parser.parse_known_args(sys.argv[1:])

    # Create an instance of DbtCoverage with the provided arguments
    dbt_parse = DbtCoverage(
        project_dir=cli_args.project_dir,
        filtered_models=cli_args.filtered_models.split(),
    )
    # Get the joined models
    joined_models = dbt_parse.combine_yml_and_sql()

    # a non-empty GITHUB_OUTPUT infers the script is running in GitHub
    if os.getenv("GITHUB_OUTPUT"):
        # Append to GITHUB_OUTPUT
        stats_tbl = dbt_parse.total_coverage_stats_tbl(joined_models=joined_models, gh_comment=True)
        column_report = dbt_parse.column_report_md(joined_models=joined_models)
        report_output = f"{stats_tbl}\n{column_report}"
        set_multiline_output("report_output", report_output)
    # this is for local runs
    else:
        # Output total coverage stats table
        stats_tbl = dbt_parse.total_coverage_stats_tbl(joined_models=joined_models)
        column_report = dbt_parse.column_report(joined_models=joined_models)
        report_output = f"{stats_tbl}\n{column_report}"
        print(report_output)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment