Skip to content

Instantly share code, notes, and snippets.

@auvipy
Forked from charettes/compare.py
Created November 11, 2022 08:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save auvipy/3942923cd075b5a0a6c227d91c46a8a3 to your computer and use it in GitHub Desktop.
Django test suite SQL output comparison
import argparse
import difflib
import os
import sys
from itertools import chain
import yaml
file_name_format = "{sha}:{vendor}.yml"


def _load_queries(directory, sha, vendor):
    """Load the queries recorded for every test of one run.

    Reads ``<directory>/<sha>:<vendor>.yml``, a stream of YAML documents.
    Each document holds a ``test`` label and a ``queries`` list whose first
    item is the list of SQL statements captured for that test. Empty
    documents are skipped, such as the one following a final ``---``
    separator. If a test is recorded more than once, the last document wins.

    Returns a dict mapping each test label to its list of queries.
    """
    path = os.path.join(directory, file_name_format.format(sha=sha, vendor=vendor))
    with open(path, encoding="utf-8") as stream:
        return {
            entry["test"]: entry["queries"][0]
            for entry in yaml.load_all(stream, Loader=yaml.SafeLoader)
            if entry
        }


def _split_lines(queries):
    """Flatten queries into newline-terminated lines, as difflib expects."""
    return list(
        chain.from_iterable(f"{query}\n".splitlines(True) for query in queries)
    )


def compare(control_sha: str, feature_sha: str, vendor: str):
    """Write a unified diff of per-test SQL between two recorded runs to stdout.

    Both recordings are read from the directory containing this script.
    Only tests present in both runs are compared. Tests missing from the
    feature run, and tests that exist only in the feature run, are silently
    ignored.

    :param control_sha: SHA identifying the baseline recording.
    :param feature_sha: SHA identifying the recording to compare against it.
    :param vendor: database vendor name used in the recording file names.
    :raises OSError: if either recording file cannot be opened.
    """
    directory = os.path.dirname(__file__)
    control = _load_queries(directory, control_sha, vendor)
    feature = _load_queries(directory, feature_sha, vendor)
    for test, control_queries in control.items():
        feature_queries = feature.get(test)
        # Skip tests the feature run did not record, and tests whose SQL is unchanged.
        if feature_queries is None or feature_queries == control_queries:
            continue
        sys.stdout.writelines(
            difflib.unified_diff(
                _split_lines(control_queries),
                _split_lines(feature_queries),
                fromfile=f"{test}:{vendor}:{control_sha}",
                tofile=f"{test}:{vendor}:{feature_sha}",
            )
        )
if __name__ == "__main__":
    # Command-line entry point: compare two recordings for one database vendor.
    parser = argparse.ArgumentParser(
        description=(
            "Diff the SQL recorded for each test of the Django test suite "
            "between two commits."
        )
    )
    parser.add_argument("control_sha", help="SHA of the baseline recording")
    parser.add_argument("feature_sha", help="SHA of the recording to compare")
    parser.add_argument(
        "-v",
        "--vendor",
        default="postgresql",
        help="database vendor of the recordings (default: %(default)s)",
    )
    compare(**vars(parser.parse_args()))
diff --git a/django/db/backends/utils.py b/django/db/backends/utils.py
index d505cd7904..d6f8c58acf 100644
--- a/django/db/backends/utils.py
+++ b/django/db/backends/utils.py
@@ -117,7 +117,9 @@ def debug_sql(
stop = time.monotonic()
duration = stop - start
if use_last_executed_query:
- sql = self.db.ops.last_executed_query(self.cursor, sql, params)
+ executed_sql = self.db.ops.last_executed_query(self.cursor, sql, params)
+ else:
+ executed_sql = sql
try:
times = len(params) if many else ""
except TypeError:
@@ -125,7 +127,7 @@ def debug_sql(
times = "?"
self.db.queries_log.append(
{
- "sql": "%s times: %s" % (times, sql) if many else sql,
+ "sql": "%s times: %s" % (times, executed_sql) if many else executed_sql,
"time": "%.3f" % duration,
}
)
@@ -137,7 +139,8 @@ def debug_sql(
self.db.alias,
extra={
"duration": duration,
- "sql": sql,
+ "sql": executed_sql,
+ "raw_sql": sql,
"params": params,
"alias": self.db.alias,
},
diff --git a/django/test/runner.py b/django/test/runner.py
index fb4d77ed60..46f841a5e5 100644
--- a/django/test/runner.py
+++ b/django/test/runner.py
@@ -43,6 +43,49 @@
     tblib = None
 
 
+import subprocess
+
+sha = (
+    subprocess.check_output(["git", "rev-parse", "--short", "HEAD^"])
+    .decode("ascii")
+    .strip()
+)
+
+
+class RecorderHandler(logging.Handler):
+    def __init__(self, test):
+        self.test = test
+        self.queries = []
+        self.flushed = False
+        super().__init__(logging.DEBUG)
+
+    def emit(self, record):
+        # raw_sql is only set by debug_sql(); skip other backend log records.
+        query = getattr(record, "raw_sql", None)
+        if query is None or query.startswith(
+            ("SAVEPOINT", "RELEASE SAVEPOINT", "EXPLAIN")
+        ):
+            return
+        self.queries.append(sqlparse.format(query, reindent=True, keyword_case="upper"))
+
+    def flush(self):
+        import os
+        import yaml
+        # logging.shutdown() may call flush() again; record each test once.
+        if self.flushed:
+            return
+        self.flushed = True
+        os.makedirs("tests/.rsql", exist_ok=True)
+        vendor = connections["default"].vendor
+        with open(f"tests/.rsql/{sha}:{vendor}.yml", "a") as file:
+            yaml.dump(
+                {"test": str(self.test), "queries": [self.queries]},
+                file,
+                sort_keys=False,
+            )
+            file.write("---\n")
+
+
 class DebugSQLTextTestResult(unittest.TextTestResult):
     def __init__(self, stream, descriptions, verbosity):
         self.logger = logging.getLogger("django.db.backends")
@@ -54,15 +97,19 @@ def startTest(self, test):
self.debug_sql_stream = StringIO()
self.handler = logging.StreamHandler(self.debug_sql_stream)
self.logger.addHandler(self.handler)
+ self.record_handler = RecorderHandler(test)
+ self.logger.addHandler(self.record_handler)
super().startTest(test)
def stopTest(self, test):
super().stopTest(test)
self.logger.removeHandler(self.handler)
+ self.logger.removeHandler(self.record_handler)
if self.showAll:
self.debug_sql_stream.seek(0)
self.stream.write(self.debug_sql_stream.read())
self.stream.writeln(self.separator2)
+ self.record_handler.flush()
def addError(self, test, err):
super().addError(test, err)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment