Skip to content

Instantly share code, notes, and snippets.

@sgibson91
Last active June 25, 2020 21:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sgibson91/44a1c3a6bbf34257dbdbb621a98dab0d to your computer and use it in GitHub Desktop.
Save sgibson91/44a1c3a6bbf34257dbdbb621a98dab0d to your computer and use it in GitHub Desktop.
Use the GitHub API to update the base branch of open Pull Requests
"""
Script to pull a list of open Pull Requests for a GitHub repository and change
the base branch using the GitHub API. For this script to run successfully, you
must have write access to the target repository. If the repository belongs to
an organisation, then you must be a member of that organisation with write
access.
This script requires a GitHub Personal Access Token. To update Pull Requests in
public repos, the token only needs the `public_repo` scope. To update Pull
Requests in private repos, the token needs the full `repo` scope.
Author: Sarah Gibson (https://github.com/sgibson91)
Python version: >= 3.7 (developed with 3.8)
Package Requirements: requests
>>> pip install requests
"""
import sys
import json
import getpass
import logging
import requests
import argparse
# Setup log config: the root logger emits every record at DEBUG level and
# above, each prefixed with a "[date time LEVEL]" stamp.
logging.basicConfig(
    format="[%(asctime)s %(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.DEBUG,
)
def parse_args(args=None):
    """Parse command line arguments

    Args:
        args (list, optional): Command line arguments, excluding the program
            name (e.g. ``sys.argv[1:]``). When None, argparse falls back to
            reading ``sys.argv[1:]`` itself.

    Returns:
        (Namespace object): The parsed command line arguments, with attributes
            ``repo_owner``, ``repo_name``, ``old_branch`` and ``new_branch``
    """
    parser = argparse.ArgumentParser(
        description="Update the base branch of open Pull Requests in a GitHub repository"
    )

    # Required arguments
    parser.add_argument(
        "repo_owner", type=str, help="The owner of the GitHub repository",
    )
    parser.add_argument(
        "repo_name", type=str, help="The name of the GitHub repository",
    )

    # Optional arguments with defaults
    parser.add_argument(
        "-o",
        "--old-branch",
        type=str,
        default="master",
        help="The old base branch name",
    )
    parser.add_argument(
        "-n", "--new-branch", type=str, default="main", help="The new base branch name",
    )

    # Parse the list we were given; calling parse_args() with no argument
    # would silently ignore `args` and re-read sys.argv instead.
    return parser.parse_args(args)
def construct_api_url(repo_owner, repo_name):
    """Build the GitHub REST API base URL for a repository.

    Args:
        repo_owner (str): The account (or organisation) that owns the
            target repository
        repo_name (str): The name of the target repository

    Returns:
        (str): A URL of the form
            ``https://api.github.com/repos/<owner>/<name>/``. The trailing
            slash lets callers append endpoint paths such as ``pulls``.
    """
    api_root = "https://api.github.com/repos/"
    return "{}{}/{}/".format(api_root, repo_owner, repo_name)
def get_request(url, header, params=None):
    """Perform a HTTP GET request to a URL with headers and additional parameters.

    Where responses are paginated (return with a "next" link), the links are
    followed and all results are compiled from all pages. The status of every
    page is checked. If any request fails, an error is logged and the script
    exits, rather than an error payload being mixed into the results.

    Args:
        url (str): The URL to make the GET request to
        header (dict): The headers to be sent with the request
        params (dict, optional): Any additional queries to send with the
            request. Defaults to None.

    Returns:
        (dict or list): The JSON content of the response. For paginated
            responses, the items of every page concatenated into one list.
    """
    resp = requests.get(url, headers=header, params=params)
    _exit_on_failed_get(resp, url)
    full_resp = resp.json()

    # Follow pagination links. Each page is checked: an error page (e.g. rate
    # limiting) carries a JSON object, and list.extend() on a dict would
    # silently append its keys as bogus results.
    # NOTE(review): the "next" URL presumably already includes the original
    # query string, so re-sending params duplicates it; kept as before.
    while "next" in resp.links:
        next_url = resp.links["next"]["url"]
        resp = requests.get(next_url, headers=header, params=params)
        _exit_on_failed_get(resp, next_url)
        full_resp.extend(resp.json())

    return full_resp


def _exit_on_failed_get(resp, url):
    """Log an error and exit the script if a GET response is not OK."""
    if not resp.ok:
        logging.error(
            "GET request to following URL failed: %s (HTTP %s)", url, resp.status_code
        )
        sys.exit(1)
def patch_request(url, json, headers):
    """Send a HTTP PATCH request carrying a JSON body.

    Args:
        url (str): The URL to send the request to
        json (dict): The payload; requests serialises it as the JSON request
            body. (Inside this function the name shadows the `json` module.)
        headers (dict): The headers to send with the request

    Returns:
        (requests.Response): The response. Its status is not checked here,
            so callers must inspect it themselves.
    """
    response = requests.patch(url, json=json, headers=headers)
    return response
def get_open_prs(repo_url, old_branch, token):
    """List the numbers of the open Pull Requests that target ``old_branch``.

    Args:
        repo_url (str): The GitHub API URL for the target repository
        old_branch (str): The base branch whose Pull Requests should be moved
        token (str): A GitHub API token to authenticate with

    Returns:
        (list of ints): The numbers of the matching open Pull Requests
    """
    auth_header = {"Authorization": "token %s" % token}
    # Let the API do the filtering: only open PRs based on old_branch
    query = {"state": "open", "base": old_branch}

    matching_prs = get_request(repo_url + "pulls", auth_header, params=query)
    logging.info("Total number of Pull Requests: %s" % len(matching_prs))

    return [pr["number"] for pr in matching_prs]
def get_pr_details(repo_url, pr_num, token):
    """Fetch the fields that are re-sent when updating a Pull Request.

    Args:
        repo_url (str): The GitHub API URL for the target repository
        pr_num (int): The number of the Pull Request to look up
        token (str): A GitHub API token to authenticate with

    Returns:
        (dict): A dict with exactly the keys ``"title"`` and ``"body"``,
            copied from the Pull Request's API representation
    """
    pr_url = "".join([repo_url, "pulls/", str(pr_num)])
    auth_header = {"Authorization": "token %s" % token}

    pr_info = get_request(pr_url, auth_header)
    return {field: pr_info[field] for field in ("title", "body")}
def update_pr(repo_url, pr_num, payload, new_branch, token):
    """Use a HTTP PATCH request to change the base branch of an open Pull Request

    The supplied fields are sent together with ``state``,
    ``maintainers_can_edit`` and the new ``base``. The caller's ``payload``
    dict is copied, not modified. A failed update is logged as a warning
    rather than raised.

    Args:
        repo_url (str): The GitHub API URL for the target repository
        pr_num (int): The number of the Pull Request to be updated
        payload (dict): The info the Pull Request is to be updated with
            (e.g. its title and body)
        new_branch (str): The new base branch of the Pull Request
        token (str): A GitHub API token to authenticate with
    """
    # Work on a copy so the caller's payload is left untouched
    body = dict(payload)
    body["state"] = "open"
    # NOTE(review): GitHub's "Update a pull request" endpoint documents this
    # field as `maintainer_can_edit` (singular), so this key is likely ignored.
    # It is left as-is because sending the documented field may be rejected for
    # PRs the token owner did not author; confirm before renaming.
    body["maintainers_can_edit"] = True
    body["base"] = new_branch  # This is redefining the base branch of the PR

    header = {"Authorization": "token %s" % token}
    resp = patch_request(repo_url + "pulls/" + str(pr_num), body, header)

    if resp.ok:
        logging.info("Pull Request #%s successfully updated" % pr_num)
    else:
        # Include the response body (GitHub's error message), not just the status
        logging.warning(
            "Pull Request #%s was not updated\n%s\n%s" % (pr_num, resp, resp.text)
        )
def main():
    """Move every open Pull Request from the old base branch to the new one.

    Parses the command line, prompts for a GitHub API token via getpass, then
    fetches the details of each matching Pull Request and updates it in turn.
    """
    cli = parse_args(sys.argv[1:])
    api_url = construct_api_url(cli.repo_owner, cli.repo_name)
    api_token = getpass.getpass("Please provide a GitHub API token: ")

    pr_numbers = get_open_prs(api_url, cli.old_branch, api_token)
    logging.info("Number of open Pull Requests to be updated: %s" % len(pr_numbers))

    for number in pr_numbers:
        # The details are fetched before the PATCH is issued (argument order)
        update_pr(
            api_url,
            number,
            get_pr_details(api_url, number, api_token),
            cli.new_branch,
            api_token,
        )


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment