Skip to content

Instantly share code, notes, and snippets.

@lagru
Last active June 10, 2024 20:38
Show Gist options
  • Save lagru/b085c9c6c23a952dd2a1022cdf1a9398 to your computer and use it in GitHub Desktop.
Save lagru/b085c9c6c23a952dd2a1022cdf1a9398 to your computer and use it in GitHub Desktop.
A small convenience script to check out pull requests as a local branch and clean up once done
#!/usr/bin/env python
"""Checkout GitHub pull requests locally.
A small convenience script to checkout pull requests as a local branch
and clean up once done.
Inspired by git-pr [1]_ from Stéfan van der Walt.
.. [1] https://github.com/stefanv/git-tools/blob/10cd994c8737e5192ada06d850ecbdbe3f223e34/scripts/git-pr
"""
import sys
import re
import argparse
import subprocess
import urllib.request
import json
import shlex
import traceback
from contextlib import contextmanager
# Prefix of the SSH URL under which a contributor's fork is added as a remote.
BASE_URL = "git@github.com:"
# Root of the GitHub REST API repository endpoints.
GITHUB_REST_URL = "https://api.github.com/repos/"
# Local branches that `--done` refuses to delete.
BLACKLIST_BRANCH_NAMES = (
    "main",
    "master",
)
# Remotes that `--done` refuses to remove.
BLACKLIST_REMOTE_NAMES = (
    "origin",
    "upstream",
    "upstream-writeable",
)
def red(text: object) -> str:
    """Wrap `text` in the ANSI escape codes for bold red output.

    `text` may be any object (e.g. an exception instance); it is converted
    with ``str()`` by the f-string.
    """
    return f"\033[31;1m{text}\033[0m"
def blue(text: object) -> str:
    """Wrap `text` in the ANSI escape codes for (non-bold) blue output."""
    return f"\033[34m{text}\033[0m"
def bold(text: object) -> str:
    """Wrap `text` in the ANSI escape codes for bold output."""
    return f"\033[1m{text}\033[0m"
def run(cmd: str, *args: str, check: bool = True, show: bool = False) -> str:
    """Run a command while handling printing.

    `cmd` should only contain trusted input, while `args` might contain input
    from untrusted sources, e.g. fetched data from the internet. `cmd` is
    split with :func:`shlex.split`; `args` are appended unmodified and no
    shell is involved.

    Parameters
    ----------
    cmd : str
        Trusted command prefix, e.g. ``"git branch -D"``.
    *args : str
        Further arguments, passed to the program as-is.
    check : bool, optional
        Raise if the command exits with a non-zero status.
    show : bool, optional
        Echo the (shell-quoted) command line and its stripped output.

    Returns
    -------
    str
        The command's standard output, unmodified (including any trailing
        newline). Standard error is not captured.

    Raises
    ------
    ValueError
        If `cmd` is empty.
    subprocess.CalledProcessError
        If `check` is true and the command fails.
    """
    cmd_parts = shlex.split(cmd)
    # Validate before echoing, so nothing is printed for an invalid command.
    if not cmd_parts or not cmd_parts[0]:
        raise ValueError(f"cmd appears to be empty: {cmd_parts!r}")
    if show:
        # Join with `cmd` as the first element so that there is no dangling
        # space when no `args` are given.
        shown = " ".join([cmd, *(shlex.quote(arg) for arg in args)])
        print(bold(f"$ {shown}"))
    result = subprocess.run(
        cmd_parts + list(args),
        check=check,
        text=True,
        stdout=subprocess.PIPE,
        # Never use shell=True here to prevent shell injections,
        # as args might contain input fetched from the web
        shell=False,
    )
    output = result.stdout
    if show and output:
        print(output.strip())
    return output
def _parse_pr_number(value: str) -> str:
    """Check that `value` consists of ASCII digits only and return it as-is.

    Used as an argparse ``type`` so that an invalid number is reported as a
    usage error instead of being spliced into an API URL and a branch name.
    """
    if not re.fullmatch(r"[0-9]+", value):
        raise argparse.ArgumentTypeError(
            f"invalid pull request number: {value!r}"
        )
    return value


def parse_command_line() -> dict:
    """Parse the script's arguments from ``sys.argv``.

    Returns
    -------
    kwargs : dict
        Keyword arguments for `main`: ``pr_number`` (str of digits),
        ``done`` (bool), ``ref_remote`` (str) and ``fallback_branch`` (str).
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "pr_number",
        metavar="NUMBER",
        type=_parse_pr_number,
        help="Number of the pull request to check out",
    )
    parser.add_argument(
        "-d",
        "--done",
        action="store_true",
        help="Delete pull request branch and remote if they exist",
    )
    parser.add_argument(
        "--remote",
        metavar="REMOTE",
        dest="ref_remote",
        default="upstream",
        help="Remote from which pull requests are taken (default: 'upstream')",
    )
    parser.add_argument(
        "--fallback",
        dest="fallback_branch",
        metavar="BRANCH",
        default="main",
        help="When deleting, switch to this branch before doing so (default: 'main')",
    )
    return vars(parser.parse_args())
def remove_pr_branch(local_branch_name, *, fallback_branch, local_remote_name):
    """Remove the branch of a previously checked out pull request.

    Switch to `fallback_branch` first (always, so that the branch to delete
    is not the checked-out one). Then force-delete `local_branch_name`. Also
    remove `local_remote_name` if no remaining local branch has an upstream
    on that remote.

    Parameters
    ----------
    local_branch_name : str
        Local pull request branch to delete.
    fallback_branch : str
        Branch to switch to before deleting.
    local_remote_name : str
        Remote the pull request branch was fetched from.

    Raises
    ------
    RuntimeError
        If `local_branch_name` is in ``BLACKLIST_BRANCH_NAMES`` or
        `local_remote_name` is in ``BLACKLIST_REMOTE_NAMES``.
    subprocess.CalledProcessError
        If switching to `fallback_branch` fails. Failures while deleting the
        branch or removing the remote are ignored.
    """
    if local_branch_name in BLACKLIST_BRANCH_NAMES:
        raise RuntimeError(
            f"requested to delete blacklisted branch `{local_branch_name}`"
        )
    if local_remote_name in BLACKLIST_REMOTE_NAMES:
        raise RuntimeError(
            f"requested to remove blacklisted remote `{local_remote_name}`"
        )
    run("git switch", fallback_branch, show=True)
    # Not checked: the branch may already be gone, and removing the remote
    # below should still be attempted.
    run("git branch -D", local_branch_name, check=False, show=True)
    # `git branch -vv` shows a branch's upstream as "[remote/branch...]".
    # Look for that exact prefix instead of the bare remote name anywhere in
    # the output, so that e.g. "_bob" is not kept alive by "_bobby/...".
    branches = run("git branch -vv")
    if f"[{local_remote_name}/" not in branches:
        run("git remote remove", local_remote_name, check=False, show=True)
@contextmanager
def handle_exceptions():
    """Handle (un)expected exceptions in `main()`.

    Print errors in red to stderr and exit with status 1:
    - for a failed command (``subprocess.CalledProcessError``), only its
      message;
    - for any other ``Exception``, the full traceback.

    ``SystemExit`` and ``KeyboardInterrupt`` propagate unchanged.
    """
    try:
        yield
    except (SystemExit, KeyboardInterrupt):
        # These do not derive from `Exception`; listed to make intent explicit.
        raise
    except subprocess.CalledProcessError as error:
        # `run` does not capture stderr, so the command's own error output is
        # already on the terminal. Print just the summary, to stderr like the
        # other error path.
        print(red(error), file=sys.stderr)
        sys.exit(1)
    except Exception:
        print(red(traceback.format_exc()), file=sys.stderr)
        sys.exit(1)
def _parse_owner_and_repo(repo_url):
    """Return the ``(owner, repo)`` pair of a GitHub remote URL.

    Accepts SSH (``git@github.com:owner/repo.git``) and HTTPS
    (``https://github.com/owner/repo``) forms, with or without a ``.git``
    suffix or trailing slash. Repository names may contain dots.

    Raises
    ------
    RuntimeError
        If no ``owner/repo`` part can be found at the end of `repo_url`.
    """
    match = re.search(
        r"(?P<owner>[\w-]+)/(?P<repo>[\w.-]+?)(?:\.git)?/?$", repo_url.strip()
    )
    if match is None:
        raise RuntimeError(
            f"cannot determine GitHub owner and repository from {repo_url!r}"
        )
    return match["owner"], match["repo"]


def _fetch_pr_data(owner, repo, pr_number):
    """Return the GitHub REST API data of pull request `pr_number` as a dict."""
    pr_url = f"{GITHUB_REST_URL}{owner}/{repo}/pulls/{pr_number}"
    request = urllib.request.Request(
        pr_url, headers={"Accept": "application/vnd.github+json"}
    )
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())


def _checkout_pr_branch(
    *, local_remote_name, remote_url, branch_name, local_branch_name, pr_html_url
):
    """Fetch `branch_name` from the contributor's fork and check it out.

    Add `local_remote_name` (pointing at `remote_url`) if it does not exist
    yet. On first checkout, create `local_branch_name` from the fetched
    branch and record `pr_html_url` as its description. Otherwise, switch to
    it and merge the fetched branch.
    """
    # Compare whole lines: a substring test would consider remote "_bob"
    # present when only "_bobby" exists, and then the fetch would fail.
    if local_remote_name not in run("git remote").splitlines():
        run("git remote add", local_remote_name, remote_url, show=True)
    run("git fetch", local_remote_name, branch_name, show=True)
    remote_ref = f"{local_remote_name}/{branch_name}"
    # `git branch --list NAME` prints nothing unless exactly that branch
    # exists (ref names cannot contain glob characters).
    if not run("git branch --list", local_branch_name).strip():
        run("git checkout -b", local_branch_name, remote_ref, show=True)
        run("git config", f"branch.{local_branch_name}.description", pr_html_url)
    else:
        run("git switch", local_branch_name, show=True)
        run("git merge", remote_ref, show=True)


def main(*, pr_number: str, done: bool, ref_remote: str, fallback_branch: str):
    """Run the script.

    Look up pull request `pr_number` of the GitHub repository that
    `ref_remote` points to. By default, check out its branch locally as
    ``pr/<number>_<head branch>``, fetched from a remote named
    ``_<head user login>``. With `done`, delete that branch (and possibly
    the remote) via `remove_pr_branch` instead.

    Check `parse_command_line` for the meaning of the parameters.
    """
    repo_url = run("git config --get", f"remote.{ref_remote}.url")
    ref_owner, ref_repo = _parse_owner_and_repo(repo_url)
    pr_data = _fetch_pr_data(ref_owner, ref_repo, pr_number)

    pr_title = pr_data["title"]
    pr_html_url = pr_data["html_url"]
    remote_name = pr_data["head"]["user"]["login"]
    branch_name = pr_data["head"]["ref"]
    print(blue(f"{pr_title}\n{pr_html_url}\n{remote_name}:{branch_name}"))

    local_remote_name = f"_{remote_name}"
    local_branch_name = f"pr/{pr_number}_{branch_name}"

    if done:
        remove_pr_branch(
            local_remote_name=local_remote_name,
            fallback_branch=fallback_branch,
            local_branch_name=local_branch_name,
        )
        return

    # A fork is not necessarily named like the upstream repository (e.g. a
    # renamed fork). Prefer the head repository's own name; it is `null` in
    # the API data when the fork was deleted, so fall back to the upstream
    # name then.
    head_repo = pr_data["head"].get("repo") or {}
    fork_repo = head_repo.get("name") or ref_repo
    _checkout_pr_branch(
        local_remote_name=local_remote_name,
        remote_url=f"{BASE_URL}{remote_name}/{fork_repo}",
        branch_name=branch_name,
        local_branch_name=local_branch_name,
        pr_html_url=pr_html_url,
    )
if __name__ == "__main__":
    # Argument parsing also happens inside the handler; argparse's own
    # SystemExit passes through it unchanged.
    with handle_exceptions():
        main(**parse_command_line())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment