Skip to content

Instantly share code, notes, and snippets.

@recht
Created December 28, 2023 23:44
Show Gist options
  • Save recht/65f2712b7684a5a5316be2dc485b0b7d to your computer and use it in GitHub Desktop.
Save recht/65f2712b7684a5a5316be2dc485b0b7d to your computer and use it in GitHub Desktop.
Find changed Bazel targets
import hashlib
import json
import os
import subprocess
import tempfile
def get_changed_targets():
    """Return the set of Bazel targets affected by the current diff.

    Compares against HEAD^ on CI, otherwise against origin/main.  BUILD/
    WORKSPACE/.bzl changes are resolved by hashing the affected rules in the
    current tree versus a clone checked out at the base commit; plain file
    changes are resolved with a `bazel query set(...)`.

    Returns:
        A set of target labels, or None when the diff is empty.
    """
    commit = "HEAD^" if os.getenv("CI") else "origin/main"
    out = exec("git", "diff", "--name-only", "--diff-filter=ACMRTUXB", commit + "..", "--")
    # Fix: "".split("\n") yields [""], so filter empties — otherwise the
    # empty-diff guard below never fires and '' leaks into the query.
    files = [f for f in out.split("\n") if f]
    if not files:
        return
    prev_commit_files = []
    changed_files = []
    for file in files:
        if file.endswith("BUILD.bazel"):
            # A BUILD edit can affect any target in that package.
            prev_commit_files.append("//" + os.path.dirname(file) + ":*")
        elif file == "WORKSPACE" or file.endswith(".bzl"):
            # Workspace/macro edits can affect external repositories.
            prev_commit_files.append("//external:*")
        else:
            changed_files.append(file)
    changed_targets = set()
    if prev_commit_files:
        new_hashes = target_hashes(prev_commit_files, ".")
        # Clone into a tempdir and hash the same targets at the base commit.
        with tempfile.TemporaryDirectory() as tmpdir:
            print("Cloning to", tmpdir)
            exec("git", "clone", ".", tmpdir)
            exec("git", "checkout", commit, cwd=tmpdir)
            prev_hashes = target_hashes(prev_commit_files, tmpdir)
            exec("bazel", "shutdown", cwd=tmpdir)
        # Fix: use a distinct loop name — the original reused `d`, shadowing
        # the tempdir variable.
        for target, _hash in set(new_hashes.items()) - set(prev_hashes.items()):
            changed_targets.add(target)
    if changed_files:
        fs = " ".join(["'" + f + "'" for f in changed_files])
        query = f"set({fs})"
        print("Finding changed targets:", query)
        with tempfile.NamedTemporaryFile() as f:
            f.write(query.encode())
            f.flush()
            print(exec("cat", f.name))
            targets = exec("bazel", "query", "--keep_going", "--query_file", f.name).split("\n")
            # Fix: drop empty lines produced by split() on empty query output.
            changed_targets.update(t for t in targets if t)
    return changed_targets
def target_hashes(prev_commit_files, cwd):
    """Run a Bazel query over *prev_commit_files* inside *cwd* and return a
    mapping of rule name -> SHA-1 digest of the rule's attributes.

    The generator_location attribute is excluded from the hash because it
    encodes machine-local paths.  Rules living under //external: are renamed
    to the "@repo//..." form using the rule's own 'name' attribute.
    """
    with tempfile.NamedTemporaryFile() as query_file:
        query = "set(" + " ".join(prev_commit_files) + ")\n"
        query_file.write(query.encode())
        query_file.flush()
        out = exec(
            "bazel", "query", "--keep_going", "--query_file", query_file.name,
            "--output", "streamed_jsonproto", cwd=cwd,
        )
    hashes = {}
    for entry in json_stream(out):
        if entry.get("type") != "RULE":
            continue
        rule = entry["rule"]
        name = rule.get("name")
        if "//external:" in name:
            # External repositories are keyed by "@<repo>//..." instead.
            repo = next(a["stringValue"] for a in rule["attribute"] if a["name"] == "name")
            name = "@" + repo + "//..."
        digest = hashlib.sha1()
        attrs = [a for a in rule["attribute"] if a["name"] != "generator_location"]
        digest.update(json.dumps(attrs, sort_keys=True).encode())
        hashes[name] = digest.hexdigest()
    return hashes
def json_stream(json_string):
    """Yield one parsed JSON value per non-empty line of *json_string*.

    Matches the newline-delimited JSON produced by
    `bazel query --output streamed_jsonproto`.

    Raises:
        ValueError: when a line is not valid JSON, or has trailing garbage
            after a complete JSON value.
    """
    dec = json.JSONDecoder()
    for line in json_string.split("\n"):
        line = line.strip()
        if not line:
            continue
        try:
            data, end = dec.raw_decode(line)
        except json.decoder.JSONDecodeError as e:
            # Fix: chain the original decode error so the traceback keeps
            # the position information ("from e" was missing).
            raise ValueError(f"JSON Decode Error for line: '{line}': {e}") from e
        # Fix: moved out of the try block — this ValueError was raised inside
        # it and only avoided the except clause by not being a JSONDecodeError.
        if end != len(line):
            raise ValueError(f"Extra characters after JSON data in line: '{line}'")
        yield data
def exec(*args, allow_fail=True, stream=False, **kwargs):
    """Run *args* as a subprocess, terminating the program on failure.

    By default stdout is captured to a temporary file and returned as a
    stripped, decoded string; with stream=True output goes wherever the
    caller directed it and None is returned.  Exit code 3 (Bazel's
    "--keep_going finished with errors" code) is tolerated when allow_fail
    is true; any other non-zero code prints the captured output and exits
    the whole process via os._exit.

    NOTE: this intentionally shadows the `exec` builtin — keep the name, the
    rest of the script calls it.
    """
    print("Running", " ".join(args))
    with tempfile.NamedTemporaryFile() as capture:
        if not stream:
            kwargs["stdout"] = capture
        proc = subprocess.Popen(args, close_fds=True, **kwargs)
        returncode = proc.wait()
        failed = returncode != 0 and (not allow_fail or returncode != 3)
        if failed:
            if not stream:
                capture.seek(0)
                print(capture.read().decode())
            print("Command failed with code", returncode)
            os._exit(returncode)
        if stream:
            return
        capture.seek(0)
        return capture.read().strip().decode()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment